diff --git a/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java b/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java
index f02fe00..98fb525 100755
--- a/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java
+++ b/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java
@@ -77,8 +77,7 @@ public VectorSchemaRoot executeQuery(String query) throws Exception {
try {
ArrowResultSet resultSet = executeQueryBatched(query);
VectorSchemaRoot result = resultSet.nextBatch();
- resultSet.close();
- logger.info("Successfully executed query and got results");
+ logger.info("Successfully executed query and got results. Remember to close the vector root to release memory.");
return result;
} catch (Exception e) {
logger.severe("Error executing query: " + e.getMessage());
diff --git a/calcite-rs-jni/odbc/.run/DDN-ODBC-Tester.run.xml b/calcite-rs-jni/odbc/.run/DDN-ODBC-Tester.run.xml
new file mode 100755
index 0000000..bb53872
--- /dev/null
+++ b/calcite-rs-jni/odbc/.run/DDN-ODBC-Tester.run.xml
@@ -0,0 +1,21 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt b/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt
index 47dacfa..c225c03 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt
@@ -101,6 +101,7 @@ set(SOURCE_FILES
src/SQLGetStmtAttr.cpp
src/SQLSetDescField.cpp
src/SQLGetTypeInfo.cpp
+ src/SQLBindParameter.cpp
)
# Header files
@@ -110,6 +111,8 @@ set(HEADER_FILES
include/globals.hpp
include/logging.hpp
include/statement.hpp
+ include/JniParam.hpp
+ src/JniParam.cpp
)
# Function to filter files starting with `._`
@@ -125,6 +128,8 @@ endfunction()
filter_files(SOURCE_FILES)
filter_files(HEADER_FILES)
+add_definitions(-D_SILENCE_CXX17_C_HEADER_DEPRECATION_WARNING)
+
# Create the library
add_library(${PROJECT_NAME} SHARED
${SOURCE_FILES}
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/JniParam.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/JniParam.hpp
new file mode 100755
index 0000000..c9496a7
--- /dev/null
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/JniParam.hpp
@@ -0,0 +1,63 @@
+//
+// Created by kennethstott on 11/26/2024.
+//
+
+#ifndef JNIPARAM_H
+#define JNIPARAM_H
+
+#endif //JNIPARAM_H
+
+#pragma once
+
+#include
+#include
+#include
+
+class JniParam {
+public:
+ enum class Type {
+ String,
+ StringArray,
+ Integer,
+ Float,
+ Double,
+ Boolean
+ };
+
+ // Constructors for different types
+ explicit JniParam(const std::string& value);
+ explicit JniParam(const std::vector& value);
+ explicit JniParam(int value);
+ explicit JniParam(float value);
+ explicit JniParam(double value);
+ explicit JniParam(bool value);
+
+ JniParam();
+
+ // Get the JNI signature for this type
+ std::string getSignature() const;
+
+ // Convert the parameter to JNI value
+ jvalue toJValue(JNIEnv* env) const;
+
+ // Clean up any JNI resources
+ void cleanup(JNIEnv* env, const jvalue& value) const;
+
+ // Get the parameter type
+ [[nodiscard]] Type getType() const { return type_; }
+ [[nodiscard]] std::string getString() const { return stringValue_; }
+ [[nodiscard]] std::vector getStringArray() const { return stringArrayValue_; }
+ [[nodiscard]] int getInt() const { return intValue_; }
+ [[nodiscard]] float getFloat() const { return floatValue_; }
+ [[nodiscard]] double getDouble() const { return doubleValue_; }
+ [[nodiscard]] bool getBool() const { return boolValue_; }
+
+private:
+ Type type_;
+ std::string stringValue_;
+ std::vector stringArrayValue_;
+ int intValue_ = 0;
+ float floatValue_ = 0.0f;
+ double doubleValue_ = 0.0;
+ bool boolValue_ = false;
+};
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp
index 1dfd875..50b0dc7 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp
@@ -1,93 +1,120 @@
#pragma once
-
-// Windows includes must come first
-#define WIN32_LEAN_AND_MEAN
+#include
#include
+#define WIN32_LEAN_AND_MEAN
-// SQL includes
+// ODBC includes
#include
+// JNI includes
+#include
// Standard includes
+#include
#include
+#include
-
-// JNI includes
-#include
-
+#include "JniParam.hpp"
#include "statement.hpp"
// Forward declarations
-class Environment;
+class Statement;
+
+struct ConnectionParams {
+ std::string server;
+ std::string port;
+ std::string database;
+ std::string role;
+ std::string auth;
+ std::string uid;
+ std::string pwd;
+ std::string encrypt = "no";
+ std::string timeout;
+
+ bool isValid() const {
+ return !server.empty() && !port.empty() && !database.empty();
+ }
+};
+
+// struct ColumnDesc {
+// std::string name;
+// SQLSMALLINT sqlType;
+// SQLULEN columnSize;
+// SQLSMALLINT decimalDigits;
+// SQLSMALLINT nullable;
+// };
class Connection {
-private:
- JavaVM* jvm{};
- jobject wrapperInstance{};
- jclass wrapperClass{};
- std::string connectionString;
- bool isConnected{};
-
- struct ConnectionParams {
- std::string server;
- std::string port;
- std::string database;
- std::string role;
- std::string auth;
- std::string uid;
- std::string pwd;
- std::string encrypt;
- std::string timeout;
-
- [[nodiscard]] bool isValid() const {
- return !server.empty() && !port.empty() && !database.empty();
- }
- };
+public:
+ Connection() = default;
+ ~Connection();
- bool initJVM();
- bool initWrapper(const ConnectionParams& params);
+ // Delete copy constructor and assignment operator
+ Connection(const Connection&) = delete;
+ Connection& operator=(const Connection&) = delete;
- static bool parseConnectionString(const std::string& connStr, ConnectionParams& params);
+ // Connection string methods
+ void setConnectionString(const std::string& dsn, const std::string& uid, const std::string& authStr);
+ void setConnectionString(const std::string& dsn);
- static std::string buildJdbcUrl(const ConnectionParams ¶ms);
- std::vector activeStmts;
+ // Connection management
+ SQLRETURN connect();
+ SQLRETURN disconnect();
-public:
+ // Query methods
+ SQLRETURN Query(const std::string& query, Statement* stmt);
+ SQLRETURN GetTables(
+ const std::string& catalogName,
+ const std::string& schemaName,
+ const std::string& tableName,
+ const std::string& tableType,
+ Statement* stmt) const;
+ SQLRETURN GetColumns(
+ const std::string& catalogName,
+ const std::string& schemaName,
+ const std::string& tableName,
+ const std::string& columnName,
+ Statement* stmt) const;
+
+ // Statement management
bool hasActiveStmts() const;
-
void cleanupActiveStmts();
- SQLRETURN GetTables(const std::string &catalogName, const std::string &schemaName, const std::string &tableName,
- const std::string &tableType, Statement *stmt) const;
-
- SQLRETURN GetColumns(const std::string &catalogName, const std::string &schemaName, const std::string &tableName,
- const std::string &columnName, Statement *stmt) const;
-
- Connection() = default;
- ~Connection();
-
- SQLRETURN connect();
+ // Get connection state
+ bool isConnected() const { return connected; }
+ JNIEnv* env = nullptr;
+ [[nodiscard]] const std::string& getConnectionString() const { return connectionString; }
- void setConnectionString(const std::string &dsn, const std::string &uid, const std::string &authStr);
- void setConnectionString(const std::string &dsn);
+private:
+ // Connection state
+ bool connected = false;
+ std::string connectionString;
+ std::vector activeStmts;
+ // JVM/JNI state
+ JavaVM* jvm = nullptr;
+ jclass wrapperClass = nullptr;
+ jobject wrapperInstance = nullptr;
+ // Initialization helpers
+ bool initJVM();
+ bool initWrapper(const ConnectionParams& params);
+ bool parseConnectionString(const std::string& connStr, ConnectionParams& params);
+ std::string buildJdbcUrl(const ConnectionParams& params);
+ SQLRETURN populateColumnDescriptors(jobject schemaRoot, Statement *stmt) const;
- // Delete copy constructor and assignment operator
- Connection(const Connection&) = delete;
- Connection& operator=(const Connection&) = delete;
-
- SQLRETURN disconnect();
- JNIEnv* env{};
+ // Result set handling
+ SQLRETURN executeAndGetArrowResult(
+ const char *methodName,
+ const std::vector ¶ms,
+ Statement *stmt) const;
- // Getters
- [[nodiscard]] JavaVM* getJVM() const { return jvm; }
- [[nodiscard]] jobject getWrapperInstance() const { return wrapperInstance; }
- [[nodiscard]] jclass getWrapperClass() const { return wrapperClass; }
- [[nodiscard]] const std::string& getConnectionString() const { return connectionString; }
+ // Type mapping helpers
+ static SQLSMALLINT mapArrowTypeToSQL(JNIEnv* env, jobject arrowType);
+ static SQLULEN getSQLTypeSize(SQLSMALLINT sqlType);
};
// Helper functions declarations
-std::string GetModuleDirectory(HMODULE hModule);
+std::string GetModuleDirectory();
std::string WideStringToString(const std::wstring& wstr);
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp
index 0034a45..61d2ce9 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp
@@ -1,5 +1,7 @@
#pragma once
+#include
#include
+#define WIN32_LEAN_AND_MEAN
#include
#include
#include
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp
index 798ce37..51c5aed 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp
@@ -7,6 +7,8 @@
#include
#include // Add this if you're using JNI types
+#include "JniParam.hpp"
+
// Forward declaration of Connection class
class Connection;
@@ -17,6 +19,27 @@ struct ColumnDesc {
SQLULEN columnSize;
SQLSMALLINT decimalDigits;
SQLSMALLINT nullable;
+ const char* catalogName;
+ SQLSMALLINT catalogNameLength;
+ const char* schemaName;
+ SQLSMALLINT schemaNameLength;
+ const char* tableName;
+ SQLSMALLINT tableNameLength;
+ const char* baseColumnName;
+ SQLSMALLINT baseColumnNameLength;
+ const char* baseTableName;
+ SQLSMALLINT baseTableNameLength;
+ const char* literalPrefix;
+ SQLSMALLINT literalPrefixLength;
+ const char* literalSuffix;
+ SQLSMALLINT literalSuffixLength;
+ const char* localTypeName;
+ SQLSMALLINT localTypeNameLength;
+ SQLSMALLINT unnamed;
+ const char* label;
+ SQLSMALLINT labelLength;
+ SQLULEN displaySize;
+ SQLSMALLINT scale;
};
struct ColumnData {
@@ -25,9 +48,11 @@ struct ColumnData {
};
class Statement {
-public:
- std::vector setupColumnResultColumns();
+private:
+ std::vector boundParams;
+ std::string originalQuery;
+public:
explicit Statement(Connection* connection);
// Delete copy constructor and assignment operator
@@ -36,11 +61,17 @@ class Statement {
void clearResults();
- std::vector setupTableResultColumns();
+ SQLRETURN bindParameter(SQLUSMALLINT parameterNumber, SQLSMALLINT inputOutputType, SQLSMALLINT valueType,
+ SQLSMALLINT parameterType, SQLULEN columnSize, SQLSMALLINT decimalDigits,
+ SQLPOINTER parameterValuePtr, SQLLEN bufferLength, SQLLEN *strLen_or_IndPtr);
+
+ std::string escapeString(const std::string &str) const;
+ std::string buildInterpolatedQuery() const;
SQLRETURN setArrowResult(jobject schemaRoot, const std::vector &columnDescriptors);
+ SQLRETURN setOriginalQuery(const std::string &query) { originalQuery = query; return SQL_SUCCESS; }
SQLRETURN getData(SQLUSMALLINT colNum, SQLSMALLINT targetType,
- SQLPOINTER targetValue, SQLLEN bufferLength,
- SQLLEN* strLengthOrIndicator);
+ SQLPOINTER targetValue, SQLLEN bufferLength,
+ SQLLEN* strLengthOrIndicator);
SQLRETURN fetch();
[[nodiscard]] SQLRETURN getFetchStatus() const;
[[nodiscard]] bool hasData() const;
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def b/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def
index 4ee45f1..166dcd4 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def
@@ -10,8 +10,8 @@ EXPORTS
SQLDisconnect
SQLGetInfo=SQLGetInfo_A
SQLGetInfoW=SQLGetInfo_W
- SQLExecDirect=SQLExecDirect_A
- SQLExecDirectW=SQLExecDirect_W
+ SQLExecDirect
+ SQLExecDirectW
SQLSetDescField
SQLSetDescFieldW
SQLTables=SQLTables_A
@@ -27,6 +27,7 @@ EXPORTS
SQLColAttributeW
SQLFetch
SQLGetTypeInfo
+ SQLBindParameter
SQLGetTypeInfoW
SQLGetData=SQLGetData_A
SQLGetDataW=SQLGetData_W
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/JniParam.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/JniParam.cpp
new file mode 100755
index 0000000..65f2d88
--- /dev/null
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/JniParam.cpp
@@ -0,0 +1,136 @@
+#include "../include/JniParam.hpp"
+#include "../include/logging.hpp"
+
+JniParam::JniParam(const std::string& value)
+ : type_(Type::String)
+ , stringValue_(value)
+{}
+
+JniParam::JniParam(const std::vector& value)
+ : type_(Type::StringArray)
+ , stringArrayValue_(value)
+{}
+
+JniParam::JniParam(int value)
+ : type_(Type::Integer)
+ , intValue_(value)
+{}
+
+JniParam::JniParam(float value)
+ : type_(Type::Float)
+ , floatValue_(value)
+{}
+
+JniParam::JniParam(double value)
+ : type_(Type::Double)
+ , doubleValue_(value)
+{}
+
+JniParam::JniParam(bool value)
+ : type_(Type::Boolean)
+ , boolValue_(value)
+{}
+
+JniParam::JniParam() {}
+
+std::string JniParam::getSignature() const {
+ switch (type_) {
+ case Type::String:
+ return "Ljava/lang/String;";
+ case Type::StringArray:
+ return "[Ljava/lang/String;";
+ case Type::Integer:
+ return "I";
+ case Type::Float:
+ return "F";
+ case Type::Double:
+ return "D";
+ case Type::Boolean:
+ return "Z";
+ default:
+ return "";
+ }
+}
+
+jvalue JniParam::toJValue(JNIEnv* env) const {
+ jvalue val{};
+ try {
+ switch (type_) {
+ case Type::String:
+ val.l = stringValue_.empty() ?
+ nullptr :
+ env->NewStringUTF(stringValue_.c_str());
+ LOGF("Created jstring from: %s", stringValue_.c_str());
+ break;
+
+ case Type::StringArray: {
+ jclass stringClass = env->FindClass("java/lang/String");
+ if (!stringClass) {
+ LOG("Failed to find String class");
+ break;
+ }
+
+ jobjectArray arr = env->NewObjectArray(
+ stringArrayValue_.size(), stringClass, nullptr);
+ if (!arr) {
+ LOG("Failed to create String array");
+ env->DeleteLocalRef(stringClass);
+ break;
+ }
+
+ for (size_t i = 0; i < stringArrayValue_.size(); i++) {
+ jstring str = env->NewStringUTF(stringArrayValue_[i].c_str());
+ if (str) {
+ env->SetObjectArrayElement(arr, i, str);
+ env->DeleteLocalRef(str);
+ }
+ }
+
+ val.l = arr;
+ env->DeleteLocalRef(stringClass);
+ LOGF("Created String array with %zu elements", stringArrayValue_.size());
+ break;
+ }
+
+ case Type::Integer:
+ val.i = intValue_;
+ LOGF("Set integer value: %d", intValue_);
+ break;
+
+ case Type::Float:
+ val.f = floatValue_;
+ LOGF("Set float value: %f", floatValue_);
+ break;
+
+ case Type::Double:
+ val.d = doubleValue_;
+ LOGF("Set double value: %f", doubleValue_);
+ break;
+
+ case Type::Boolean:
+ val.z = boolValue_;
+ LOGF("Set boolean value: %d", boolValue_);
+ break;
+ }
+ }
+ catch (const std::exception& e) {
+ LOGF("Exception in toJValue: %s", e.what());
+ // Ensure val is zeroed in case of error
+ val = jvalue{};
+ }
+ return val;
+}
+
+void JniParam::cleanup(JNIEnv* env, const jvalue& value) const {
+ try {
+ if (type_ == Type::String || type_ == Type::StringArray) {
+ if (value.l != nullptr) {
+ env->DeleteLocalRef(value.l);
+ LOG("Cleaned up JNI reference");
+ }
+ }
+ }
+ catch (const std::exception& e) {
+ LOGF("Exception in cleanup: %s", e.what());
+ }
+}
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLBindParameter.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLBindParameter.cpp
new file mode 100755
index 0000000..3bc5a9d
--- /dev/null
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLBindParameter.cpp
@@ -0,0 +1,42 @@
+#include
+#include
+#define WIN32_LEAN_AND_MEAN
+#include "../include/connection.hpp"
+#include "../include/logging.hpp"
+//#include "../include/httplib.h"
+#include "../include/globals.hpp"
+#include "../include/environment.hpp"
+#include "../include/statement.hpp"
+
+extern "C" {
+ SQLRETURN SQL_API SQLBindParameter(
+ SQLHSTMT StatementHandle,
+ SQLUSMALLINT ParameterNumber,
+ SQLSMALLINT InputOutputType,
+ SQLSMALLINT ValueType,
+ SQLSMALLINT ParameterType,
+ SQLULEN ColumnSize,
+ SQLSMALLINT DecimalDigits,
+ SQLPOINTER ParameterValuePtr,
+ SQLLEN BufferLength,
+ SQLLEN *StrLen_or_IndPtr
+ ) {
+ if (!StatementHandle) {
+ return SQL_ERROR;
+ }
+
+ Statement* stmt = reinterpret_cast(StatementHandle);
+
+ return stmt->bindParameter(
+ ParameterNumber,
+ InputOutputType,
+ ValueType,
+ ParameterType,
+ ColumnSize,
+ DecimalDigits,
+ ParameterValuePtr,
+ BufferLength,
+ StrLen_or_IndPtr
+ );
+ }
+}
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp
index d16650d..cecb160 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp
@@ -10,222 +10,244 @@
// Column attribute functions
extern "C" {
-
SQLRETURN SQL_API SQLColAttribute(
- SQLHSTMT StatementHandle,
- SQLUSMALLINT ColumnNumber,
- SQLUSMALLINT FieldIdentifier,
- SQLPOINTER CharacterAttribute,
- SQLSMALLINT BufferLength,
- SQLSMALLINT* StringLength,
- SQLLEN* NumericAttribute) {
-
- LOGF("SQLColAttribute called - Column: %u, Field: %u", ColumnNumber, FieldIdentifier);
- auto* stmt = static_cast(StatementHandle);
- if (!stmt) {
- LOG("Invalid statement handle");
- return SQL_INVALID_HANDLE;
- }
-
- if (!stmt->hasResult) {
- LOG("No result set available");
- return SQL_ERROR;
- }
-
- LOGF("Accessing column %u of %zu", ColumnNumber, stmt->resultColumns.size());
- if (ColumnNumber <= 0 || ColumnNumber > stmt->resultColumns.size()) {
- LOG("Invalid column number");
- return SQL_ERROR;
- }
-
- const auto& col = stmt->resultColumns[ColumnNumber - 1];
- LOGF("Column name: %s", col.name);
-
- switch (FieldIdentifier) {
+ SQLHSTMT StatementHandle,
+ SQLUSMALLINT ColumnNumber,
+ SQLUSMALLINT FieldIdentifier,
+ SQLPOINTER CharacterAttribute,
+ SQLSMALLINT BufferLength,
+ SQLSMALLINT *StringLength,
+ SQLLEN *NumericAttribute) {
+ auto *stmt = static_cast(StatementHandle);
+ LOGF("SQLColAttribute called - Column: %u, Field: %u, Total Size: %u", ColumnNumber, FieldIdentifier, stmt->resultColumns.size());
+ if (!stmt) {
+ LOG("Invalid statement handle");
+ return SQL_INVALID_HANDLE;
+ }
+
+ if (!stmt->hasResult) {
+ LOG("No result set available");
+ return SQL_ERROR;
+ }
+
+ LOGF("Accessing column %u of %zu", ColumnNumber, stmt->resultColumns.size());
+ if (ColumnNumber <= 0 || ColumnNumber > stmt->resultColumns.size()) {
+ LOG("Invalid column number");
+ return SQL_ERROR;
+ }
+
+ const auto &col = stmt->resultColumns[ColumnNumber - 1];
+
+ switch (FieldIdentifier) {
// Basic column metadata
- case SQL_COLUMN_COUNT: // 0
+ case SQL_COLUMN_COUNT: // 0
if (NumericAttribute) {
*NumericAttribute = stmt->resultColumns.size();
}
break;
- case SQL_COLUMN_NAME: // 1
- case SQL_DESC_NAME: // SQL_COLUMN_NAME
- case SQL_DESC_LABEL: // SQL_COLUMN_LABEL
+ case SQL_COLUMN_NAME: // 1
+ case SQL_DESC_NAME: // SQL_COLUMN_NAME
+ if (CharacterAttribute && BufferLength > 0) {
+ LOGF("Returning column name: %s, %u", col.name, col.nameLength);
+ strncpy((char *) CharacterAttribute, col.name, BufferLength);
+ if (StringLength) {
+ *StringLength = col.nameLength;
+ }
+ } else {
+ LOG("Did not return column name");
+ }
+ break;
+
+ case SQL_DESC_LABEL: // SQL_COLUMN_LABEL
if (CharacterAttribute && BufferLength > 0) {
- strncpy((char*)CharacterAttribute, col.name, BufferLength);
+ strncpy((char *) CharacterAttribute, col.label, BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(col.name));
+ *StringLength = col.labelLength;
}
}
break;
// Type information
- case SQL_COLUMN_TYPE: // 2
- case SQL_DESC_TYPE: // 1002
- // case SQL_DESC_CONCISE_TYPE: // SQL_COLUMN_TYPE
+ case SQL_COLUMN_TYPE: // 2
+ case SQL_DESC_TYPE: // 1002
+ // case SQL_DESC_CONCISE_TYPE: // SQL_COLUMN_TYPE
if (NumericAttribute) {
*NumericAttribute = col.sqlType;
}
break;
// Size and length information
- case SQL_COLUMN_LENGTH: // 3
- case SQL_DESC_LENGTH: // SQL_COLUMN_LENGTH
- case SQL_DESC_OCTET_LENGTH: // 1013
+ case SQL_COLUMN_LENGTH: // 3
+ case SQL_DESC_LENGTH: // SQL_COLUMN_LENGTH
+ case SQL_DESC_OCTET_LENGTH: // 1013
if (NumericAttribute) {
*NumericAttribute = col.columnSize;
}
break;
- case SQL_COLUMN_DISPLAY_SIZE: // 6
- case SQL_COLUMN_PRECISION: // 4
- case SQL_DESC_PRECISION: // SQL_COLUMN_PRECISION
- case SQL_COLUMN_SCALE: // 5
- case SQL_DESC_SCALE: // SQL_COLUMN_SCALE
+ case SQL_COLUMN_DISPLAY_SIZE: // 6
+ if (NumericAttribute) {
+ *NumericAttribute = col.displaySize;
+ }
+ break;
+
+ case SQL_COLUMN_PRECISION: // 4
+ case SQL_DESC_PRECISION: // SQL_COLUMN_PRECISION
if (NumericAttribute) {
- *NumericAttribute = 0; // VARCHAR doesn't have scale
+ *NumericAttribute = col.decimalDigits;
+ }
+ break;
+
+ case SQL_COLUMN_SCALE: // 5
+ case SQL_DESC_SCALE: // SQL_COLUMN_SCALE
+ if (NumericAttribute) {
+ *NumericAttribute = col.scale; // VARCHAR doesn't have scale
}
break;
// Nullability
- case SQL_COLUMN_NULLABLE: // 7
- case SQL_DESC_NULLABLE: // SQL_COLUMN_NULLABLE
+ case SQL_COLUMN_NULLABLE: // 7
+ case SQL_DESC_NULLABLE: // SQL_COLUMN_NULLABLE
if (NumericAttribute) {
*NumericAttribute = col.nullable;
}
break;
// Type characteristics
- case SQL_COLUMN_UNSIGNED: // 8
- // case SQL_DESC_UNSIGNED: // SQL_COLUMN_UNSIGNED
+ case SQL_COLUMN_UNSIGNED: // 8
+ // case SQL_DESC_UNSIGNED: // SQL_COLUMN_UNSIGNED
if (NumericAttribute) {
- *NumericAttribute = SQL_TRUE; // VARCHAR is unsigned
+ *NumericAttribute = SQL_TRUE; // VARCHAR is unsigned
}
break;
- case SQL_COLUMN_MONEY: // 9
- case SQL_COLUMN_AUTO_INCREMENT: // 11
- // case SQL_DESC_AUTO_UNIQUE_VALUE: // SQL_COLUMN_AUTO_INCREMENT
- if (NumericAttribute) {
- *NumericAttribute = SQL_FALSE; // VARCHAR is not auto-increment
- }
- break;
+ case SQL_COLUMN_MONEY: // 9
+ case SQL_COLUMN_AUTO_INCREMENT: // 11
+ // case SQL_DESC_AUTO_UNIQUE_VALUE: // SQL_COLUMN_AUTO_INCREMENT
+ if (NumericAttribute) {
+ *NumericAttribute = SQL_FALSE; // VARCHAR is not auto-increment
+ }
+ break;
- case SQL_COLUMN_UPDATABLE: // 10
- // case SQL_DESC_UPDATABLE: // SQL_COLUMN_UPDATABLE
+ case SQL_COLUMN_UPDATABLE: // 10
+ // case SQL_DESC_UPDATABLE: // SQL_COLUMN_UPDATABLE
if (NumericAttribute) {
- *NumericAttribute = SQL_ATTR_READONLY; // Catalog results are read-only
+ *NumericAttribute = SQL_ATTR_READONLY; // Catalog results are read-only
}
break;
- case SQL_COLUMN_CASE_SENSITIVE: // 12
- // case SQL_DESC_CASE_SENSITIVE: // SQL_COLUMN_CASE_SENSITIVE
+ case SQL_COLUMN_CASE_SENSITIVE: // 12
+ // case SQL_DESC_CASE_SENSITIVE: // SQL_COLUMN_CASE_SENSITIVE
if (NumericAttribute) {
- *NumericAttribute = SQL_TRUE; // VARCHAR is case sensitive
+ *NumericAttribute = SQL_TRUE; // VARCHAR is case sensitive
}
break;
- case SQL_COLUMN_SEARCHABLE: // 13
- // case SQL_DESC_SEARCHABLE: // SQL_COLUMN_SEARCHABLE
+ case SQL_COLUMN_SEARCHABLE: // 13
+ // case SQL_DESC_SEARCHABLE: // SQL_COLUMN_SEARCHABLE
if (NumericAttribute) {
- *NumericAttribute = SQL_SEARCHABLE; // VARCHAR is fully searchable
+ *NumericAttribute = SQL_SEARCHABLE; // VARCHAR is fully searchable
}
break;
// Type names and descriptions
- case SQL_COLUMN_TYPE_NAME: // 14
- // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME
+ case SQL_COLUMN_TYPE_NAME: // 14
+ // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME
if (CharacterAttribute && BufferLength > 0) {
- const char* typeName = "VARCHAR";
- strncpy((char*)CharacterAttribute, typeName, BufferLength);
+ strncpy((char *) CharacterAttribute, "VARCHAR", BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(typeName));
+ *StringLength = 7;
}
}
break;
// Table information
- case SQL_COLUMN_TABLE_NAME: // 15
- // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME
+ case SQL_COLUMN_TABLE_NAME: // 15
+ // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME
if (CharacterAttribute && BufferLength > 0) {
- const char* tableName = ""; // Catalog functions typically don't set this
- strncpy((char*)CharacterAttribute, tableName, BufferLength);
+ strncpy((char *) CharacterAttribute, col.tableName, BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(tableName));
+ *StringLength = col.tableNameLength;
}
}
break;
- case SQL_COLUMN_OWNER_NAME: // 16
- // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME
+ case SQL_COLUMN_OWNER_NAME: // 16
+ // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME
if (CharacterAttribute && BufferLength > 0) {
- const char* schemaName = ""; // Catalog functions typically don't set this
- strncpy((char*)CharacterAttribute, schemaName, BufferLength);
+ strncpy((char *) CharacterAttribute, col.schemaName, BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(schemaName));
+ *StringLength = col.schemaNameLength;
}
}
break;
- case SQL_COLUMN_QUALIFIER_NAME: // 17
- // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME
+ case SQL_COLUMN_QUALIFIER_NAME: // 17
+ // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME
if (CharacterAttribute && BufferLength > 0) {
- const char* catalogName = ""; // Catalog functions typically don't set this
- strncpy((char*)CharacterAttribute, catalogName, BufferLength);
+ strncpy((char *) CharacterAttribute, col.catalogName, BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(catalogName));
+ *StringLength = col.catalogNameLength;
}
}
break;
// SQL literal formatting
- case SQL_DESC_LITERAL_PREFIX: // 27
+ case SQL_DESC_LITERAL_PREFIX: // 27
if (CharacterAttribute && BufferLength > 0) {
- const char* prefix = "'"; // VARCHAR uses single quotes
- strncpy((char*)CharacterAttribute, prefix, BufferLength);
+ strncpy((char *) CharacterAttribute, col.literalPrefix, BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(prefix));
+ *StringLength = col.literalPrefixLength;
}
}
break;
- case SQL_DESC_LITERAL_SUFFIX: // 28
+ case SQL_DESC_LITERAL_SUFFIX: // 28
if (CharacterAttribute && BufferLength > 0) {
- const char* suffix = "'"; // VARCHAR uses single quotes
- strncpy((char*)CharacterAttribute, suffix, BufferLength);
+ strncpy((char *) CharacterAttribute, col.literalSuffix, BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(suffix));
+ *StringLength = col.literalSuffixLength;
}
}
break;
- // // Additional descriptors
- // case SQL_DESC_FIXED_PREC_SCALE: // 1108
- // if (NumericAttribute) {
- // *NumericAttribute = SQL_FALSE; // VARCHAR is not fixed precision
- // }
- // break;
-
- case SQL_DESC_LOCAL_TYPE_NAME: // 29
+ case SQL_DESC_LOCAL_TYPE_NAME: // 29
if (CharacterAttribute && BufferLength > 0) {
- const char* localTypeName = "VARCHAR"; // Usually same as type name
- strncpy((char*)CharacterAttribute, localTypeName, BufferLength);
+ strncpy((char *) CharacterAttribute, col.localTypeName, BufferLength);
if (StringLength) {
- *StringLength = static_cast(strlen(localTypeName));
+ *StringLength = col.localTypeNameLength;
}
}
break;
- case SQL_DESC_NUM_PREC_RADIX: // 32
+ case SQL_DESC_NUM_PREC_RADIX: // 32
if (NumericAttribute) {
- *NumericAttribute = 0; // Not applicable for VARCHAR
+ *NumericAttribute = 0; // Not applicable for VARCHAR
}
break;
- case SQL_DESC_UNNAMED: // 1089
+ case SQL_DESC_UNNAMED: // 1089
if (NumericAttribute) {
- *NumericAttribute = SQL_NAMED; // Column has a name
+ *NumericAttribute = col.unnamed;
+ }
+ break;
+
+ case SQL_DESC_BASE_COLUMN_NAME: // 1025
+ if (CharacterAttribute && BufferLength > 0) {
+ strncpy((char *) CharacterAttribute, col.baseColumnName, BufferLength);
+ if (StringLength) {
+ *StringLength = col.baseColumnNameLength;
+ }
+ }
+ break;
+
+ case SQL_DESC_BASE_TABLE_NAME: // 1026
+ if (CharacterAttribute && BufferLength > 0) {
+ strncpy((char *) CharacterAttribute, col.baseTableName, BufferLength);
+ if (StringLength) {
+ *StringLength = col.baseTableNameLength;
+ }
}
break;
@@ -238,87 +260,88 @@ SQLRETURN SQL_API SQLColAttribute(
*StringLength = 0;
}
break;
- }
+ }
- return SQL_SUCCESS;
+ return SQL_SUCCESS;
}
SQLRETURN SQL_API SQLColAttributeW(
- SQLHSTMT StatementHandle,
- SQLUSMALLINT ColumnNumber,
- SQLUSMALLINT FieldIdentifier,
- SQLPOINTER CharacterAttribute,
- SQLSMALLINT BufferLength,
- SQLSMALLINT* StringLength,
- SQLLEN* NumericAttribute) {
-
- LOG("SQLColAttributeW called");
-
- // For non-string attributes, just delegate to the ANSI version
- switch (FieldIdentifier) {
- // String attributes that need Unicode conversion
- case SQL_COLUMN_NAME: // 1
- case SQL_COLUMN_TYPE_NAME: // 14
- case SQL_COLUMN_TABLE_NAME: // 15
- case SQL_COLUMN_OWNER_NAME: // 16
- case SQL_COLUMN_QUALIFIER_NAME: // 17
- case SQL_COLUMN_LABEL: // 18
- case SQL_DESC_NAME: // SQL_COLUMN_NAME
- // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME
- // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME
- // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME
- // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME
- // case SQL_DESC_LABEL: // SQL_COLUMN_LABEL
- case SQL_DESC_LITERAL_PREFIX: // 27
- case SQL_DESC_LITERAL_SUFFIX: // 28
- case SQL_DESC_LOCAL_TYPE_NAME: // 29
- break; // Handle these below with Unicode conversion
-
- // All other attributes can go directly to ANSI version
- default:
- return SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier,
- CharacterAttribute, BufferLength, StringLength,
- NumericAttribute);
- }
-
- // Get the ANSI string
- char ansiBuffer[SQL_MAX_MESSAGE_LENGTH];
- SQLSMALLINT ansiLength = 0;
-
- SQLRETURN ret = SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier,
- ansiBuffer, sizeof(ansiBuffer), &ansiLength, NumericAttribute);
-
- if (!SQL_SUCCEEDED(ret)) {
- return ret;
- }
-
- // Convert to wide char if we have a buffer
- if (SQL_SUCCEEDED(ret) && CharacterAttribute && BufferLength > 0) {
- if (ansiLength > 0) {
- size_t numChars = 0;
- mbstowcs_s(&numChars, (wchar_t*)CharacterAttribute,
- BufferLength/sizeof(wchar_t), ansiBuffer, _TRUNCATE);
- if (StringLength) {
- *StringLength = static_cast(numChars * sizeof(wchar_t));
- }
- LOGF("Converted string to Unicode. Original length: %d, Wide length: %zu",
- ansiLength, numChars);
- } else {
- // Empty string case
- if (BufferLength >= sizeof(wchar_t)) {
- *((wchar_t*)CharacterAttribute) = L'\0';
- }
- if (StringLength) {
- *StringLength = 0;
- }
- LOG("Set empty Unicode string");
- }
- } else if (StringLength) {
- // Just return required length if no buffer provided
- *StringLength = ansiLength * sizeof(wchar_t);
- LOGF("Returning required buffer length: %d bytes", *StringLength);
- }
-
- return ret;
+ SQLHSTMT StatementHandle,
+ SQLUSMALLINT ColumnNumber,
+ SQLUSMALLINT FieldIdentifier,
+ SQLPOINTER CharacterAttribute,
+ SQLSMALLINT BufferLength,
+ SQLSMALLINT *StringLength,
+ SQLLEN *NumericAttribute) {
+ LOG("SQLColAttributeW called");
+
+ // For non-string attributes, just delegate to the ANSI version
+ switch (FieldIdentifier) {
+ // String attributes that need Unicode conversion
+ case SQL_COLUMN_NAME: // 1
+ case SQL_COLUMN_TYPE_NAME: // 14
+ case SQL_COLUMN_TABLE_NAME: // 15
+ case SQL_COLUMN_OWNER_NAME: // 16
+ case SQL_COLUMN_QUALIFIER_NAME: // 17
+ case SQL_COLUMN_LABEL: // 18
+ case SQL_DESC_NAME: // SQL_COLUMN_NAME
+ // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME
+ // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME
+ // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME
+ // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME
+ // case SQL_DESC_LABEL: // SQL_COLUMN_LABEL
+ case SQL_DESC_LITERAL_PREFIX: // 27
+ case SQL_DESC_LITERAL_SUFFIX: // 28
+ case SQL_DESC_LOCAL_TYPE_NAME: // 29
+ case SQL_DESC_BASE_COLUMN_NAME: // 1025
+ case SQL_DESC_BASE_TABLE_NAME: // 1026
+ break; // Handle these below with Unicode conversion
+
+ // All other attributes can go directly to ANSI version
+ default:
+ return SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier,
+ CharacterAttribute, BufferLength, StringLength,
+ NumericAttribute);
+ }
+
+ // Get the ANSI string
+ char ansiBuffer[SQL_MAX_MESSAGE_LENGTH];
+ SQLSMALLINT ansiLength = 0;
+
+ SQLRETURN ret = SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier,
+ ansiBuffer, sizeof(ansiBuffer), &ansiLength, NumericAttribute);
+
+ if (!SQL_SUCCEEDED(ret)) {
+ return ret;
+ }
+
+ // Convert to wide char if we have a buffer
+ if (SQL_SUCCEEDED(ret) && CharacterAttribute && BufferLength > 0) {
+ if (ansiLength > 0) {
+ size_t numChars = 0;
+ errno_t err = mbstowcs_s(&numChars, (wchar_t *)CharacterAttribute,
+ BufferLength / sizeof(wchar_t), ansiBuffer, _TRUNCATE);
+ if (err == 0 && StringLength) {
+ *StringLength = static_cast(numChars * sizeof(wchar_t));
+ }
+ LOGF("Converted string to Unicode. Original length: %d, Wide length: %zu", ansiLength, numChars);
+
+ } else {
+ // Empty string case
+ if (BufferLength >= sizeof(wchar_t)) {
+ *((wchar_t *) CharacterAttribute) = L'\0';
+ }
+ if (StringLength) {
+ *StringLength = 0;
+ }
+ LOG("Set empty Unicode string");
+ }
+ } else if (StringLength) {
+ // Just return required length if no buffer provided
+ *StringLength = ansiLength * sizeof(wchar_t);
+ LOGF("Returning required buffer length: %d bytes", *StringLength);
+ }
+
+ return ret;
}
-} // extern "C"
\ No newline at end of file
+} // extern "C"
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp
index 25341ff..960dff5 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp
@@ -3,7 +3,6 @@
#define WIN32_LEAN_AND_MEAN
#include "../include/connection.hpp"
#include "../include/logging.hpp"
-#include "../include/httplib.h"
#include "../include/globals.hpp"
#include "../include/environment.hpp"
#include "../include/statement.hpp"
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp
index 47f3fff..173d261 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp
@@ -10,63 +10,54 @@
extern "C" {
- SQLRETURN SQL_API SQLExecDirect_A(
- SQLHSTMT hstmt,
- SQLCHAR* szSqlStr,
- SQLINTEGER cbSqlStr)
- {
- std::string sqlStr(reinterpret_cast(szSqlStr),
- cbSqlStr == SQL_NTS ? strlen(reinterpret_cast(szSqlStr)) : cbSqlStr);
+ SQLRETURN SQL_API SQLExecDirect(
+ SQLHSTMT StatementHandle,
+ SQLCHAR* StatementText,
+ SQLINTEGER TextLength) {
- auto stmt = static_cast(hstmt);
- if (!stmt) {
+ LOGF("SQLExecDirect called with query: %s", StatementText);
+
+ auto* stmt = static_cast(StatementHandle);
+ if (!stmt || !stmt->conn) {
+ LOG("Invalid statement handle or connection");
return SQL_INVALID_HANDLE;
}
- return SQL_ERROR;
- }
-
- SQLRETURN SQL_API SQLExecDirect_W(
- SQLHSTMT hstmt,
- SQLWCHAR* szSqlStr,
- SQLINTEGER cbSqlStr)
- {
- // Get the actual string length
- size_t actualLength = 0;
- while (szSqlStr[actualLength] != L'\0' &&
- !(szSqlStr[actualLength] == L'\\' && szSqlStr[actualLength + 1] == L'0')) {
- actualLength++;
- }
-
- // Create wide string excluding the "\ 0" terminator
- std::wstring wsqlStr(szSqlStr, actualLength);
- std::string sqlStr = WideStringToString(wsqlStr);
+ // Convert to string based on TextLength
+ std::string query;
+ if (TextLength == SQL_NTS) {
+ query = reinterpret_cast(StatementText);
+ } else {
+ query = std::string(reinterpret_cast(StatementText), TextLength);
+ }
- LOGF("SQLExecDirect_W: Original SQL: '%s'", sqlStr.c_str());
+ LOGF("Executing query: %s", query.c_str());
+ return stmt->conn->Query(query, stmt);
+ }
- // Transform the query
- std::string transformedSql = sqlStr;
+ // And the Unicode version
+ SQLRETURN SQL_API SQLExecDirectW(
+ SQLHSTMT StatementHandle,
+ SQLWCHAR* StatementText,
+ SQLINTEGER TextLength) {
- // Replace INFORMATION_SCHEMA.TABLES
- size_t pos = transformedSql.find("INFORMATION_SCHEMA.TABLES");
- if (pos != std::string::npos) {
- transformedSql.replace(pos, std::string("INFORMATION_SCHEMA.TABLES").length(), "metadata.TABLES");
+ auto* stmt = static_cast(StatementHandle);
+ if (!stmt || !stmt->conn) {
+ LOG("Invalid statement handle or connection");
+ return SQL_INVALID_HANDLE;
}
- // Replace INFORMATION_SCHEMA.COLUMNS
- pos = transformedSql.find("INFORMATION_SCHEMA.COLUMNS");
- if (pos != std::string::npos) {
- transformedSql.replace(pos, std::string("INFORMATION_SCHEMA.COLUMNS").length(), "metadata.COLUMNS");
+ // Convert wide string to UTF-8
+ std::wstring wquery;
+ if (TextLength == SQL_NTS) {
+ wquery = reinterpret_cast(StatementText);
+ } else {
+ wquery = std::wstring(reinterpret_cast(StatementText), TextLength);
}
- LOGF("SQLExecDirect_W: Transformed SQL: '%s'", transformedSql.c_str());
-
- auto stmt = static_cast(hstmt);
- if (!stmt) {
- LOG("SQLExecDirect_W: Invalid statement handle");
- return SQL_INVALID_HANDLE;
- }
+ std::string query = WideStringToString(wquery);
+ LOGF("SQLExecDirectW executing query: %s", query.c_str());
- return SQL_ERROR;
+ return stmt->conn->Query(query, stmt);
}
} // extern "C"
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp
index 473564d..61e8e5a 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp
@@ -52,31 +52,125 @@ SQLRETURN SQL_API SQLSetDescField(
}
switch (FieldIdentifier) {
+ case SQL_COLUMN_TYPE:
case SQL_DESC_TYPE:
- case SQL_DESC_CONCISE_TYPE:
+ // case SQL_DESC_CONCISE_TYPE:
stmt->resultColumns[colIdx].sqlType = reinterpret_cast(Value);
break;
+ case SQL_COLUMN_NAME:
case SQL_DESC_NAME:
if (Value) {
+ LOGF("Setting column name to: %s", static_cast(Value));
stmt->resultColumns[colIdx].name = static_cast(Value);
stmt->resultColumns[colIdx].nameLength =
(BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].name) : BufferLength;
+ LOGF("Set column name to: %s", stmt->resultColumns[colIdx].name);
}
break;
+ // case SQL_COLUMN_LABEL:
+ case SQL_DESC_LABEL:
+ if (Value) {
+ stmt->resultColumns[colIdx].label = static_cast(Value);
+ stmt->resultColumns[colIdx].labelLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].label) : BufferLength;
+ }
+ break;
+
+ case SQL_COLUMN_NULLABLE:
case SQL_DESC_NULLABLE:
stmt->resultColumns[colIdx].nullable = reinterpret_cast(Value);
break;
-
+ case SQL_COLUMN_LENGTH:
case SQL_DESC_LENGTH:
stmt->resultColumns[colIdx].columnSize = reinterpret_cast(Value);
break;
+ case SQL_COLUMN_PRECISION:
case SQL_DESC_PRECISION:
stmt->resultColumns[colIdx].decimalDigits = reinterpret_cast(Value);
break;
+ case SQL_COLUMN_SCALE:
+ case SQL_DESC_SCALE:
+ stmt->resultColumns[colIdx].scale = reinterpret_cast(Value);
+ break;
+
+ case SQL_CATALOG_NAME:
+ case SQL_DESC_CATALOG_NAME:
+ if (Value) {
+ stmt->resultColumns[colIdx].catalogName = static_cast(Value);
+ stmt->resultColumns[colIdx].catalogNameLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].catalogName) : BufferLength;
+ }
+ break;
+
+ case SQL_DESC_SCHEMA_NAME:
+ if (Value) {
+ stmt->resultColumns[colIdx].schemaName = static_cast(Value);
+ stmt->resultColumns[colIdx].schemaNameLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].schemaName) : BufferLength;
+ }
+ break;
+
+ // case SQL_COLUMN_TABLE_NAME:
+ case SQL_DESC_TABLE_NAME:
+ if (Value) {
+ stmt->resultColumns[colIdx].tableName = static_cast(Value);
+ stmt->resultColumns[colIdx].tableNameLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].tableName) : BufferLength;
+ }
+ break;
+
+ case SQL_DESC_BASE_COLUMN_NAME:
+ if (Value) {
+ stmt->resultColumns[colIdx].baseColumnName = static_cast(Value);
+ stmt->resultColumns[colIdx].baseColumnNameLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].baseColumnName) : BufferLength;
+ }
+ break;
+
+ case SQL_DESC_BASE_TABLE_NAME:
+ if (Value) {
+ stmt->resultColumns[colIdx].baseTableName = static_cast(Value);
+ stmt->resultColumns[colIdx].baseTableNameLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].baseTableName) : BufferLength;
+ }
+ break;
+
+ case SQL_DESC_LITERAL_PREFIX:
+ if (Value) {
+ stmt->resultColumns[colIdx].literalPrefix = static_cast(Value);
+ stmt->resultColumns[colIdx].literalPrefixLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].literalPrefix) : BufferLength;
+ }
+ break;
+
+ case SQL_DESC_LITERAL_SUFFIX:
+ if (Value) {
+ stmt->resultColumns[colIdx].literalSuffix = static_cast(Value);
+ stmt->resultColumns[colIdx].literalSuffixLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].literalSuffix) : BufferLength;
+ }
+ break;
+
+ case SQL_DESC_LOCAL_TYPE_NAME:
+ if (Value) {
+ stmt->resultColumns[colIdx].localTypeName = static_cast(Value);
+ stmt->resultColumns[colIdx].localTypeNameLength =
+ (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].localTypeName) : BufferLength;
+ }
+ break;
+
+ case SQL_DESC_UNNAMED:
+ stmt->resultColumns[colIdx].unnamed = reinterpret_cast(Value);
+ break;
+
+ case SQL_DESC_DISPLAY_SIZE:
+ stmt->resultColumns[colIdx].displaySize = reinterpret_cast(Value);
+ break;
+
default:
LOGF("Unsupported field identifier: %d", FieldIdentifier);
return SQL_ERROR;
@@ -85,14 +179,26 @@ SQLRETURN SQL_API SQLSetDescField(
return SQL_SUCCESS;
}
- SQLRETURN SQL_API SQLSetDescFieldW(
- SQLHDESC DescriptorHandle,
- SQLSMALLINT RecNumber,
- SQLSMALLINT FieldIdentifier,
- SQLPOINTER Value,
- SQLINTEGER BufferLength) {
+SQLRETURN SQL_API SQLSetDescFieldW(
+ SQLHDESC DescriptorHandle,
+ SQLSMALLINT RecNumber,
+ SQLSMALLINT FieldIdentifier,
+ SQLPOINTER Value,
+ SQLINTEGER BufferLength) {
if (FieldIdentifier == SQL_DESC_NAME && Value != nullptr) {
+ // Convert wide string to narrow for internal storage
+ auto wstr = static_cast(Value);
+ std::string name = WideStringToString(wstr);
+ LOGF("Converted name: %s", name.c_str());
+ auto result = SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(name.c_str()), SQL_NTS);
+ Statement *s = static_cast(DescriptorHandle);
+ LOGF("Post SQLSetDescField: %s", s->resultColumns[RecNumber -1].name);
+ return result;
+ }
+
+ if (FieldIdentifier == SQL_COLUMN_NAME && Value != nullptr) {
// Convert wide string to narrow for internal storage
const wchar_t* wstr = static_cast(Value);
std::string name = WideStringToString(std::wstring(wstr));
@@ -100,6 +206,56 @@ SQLRETURN SQL_API SQLSetDescField(
const_cast(name.c_str()), SQL_NTS);
}
+ if (FieldIdentifier == SQL_DESC_CATALOG_NAME && Value != nullptr) {
+ const wchar_t* wstr = static_cast(Value);
+ std::string catalogName = WideStringToString(std::wstring(wstr));
+ return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(catalogName.c_str()), SQL_NTS);
+ }
+
+ if (FieldIdentifier == SQL_DESC_SCHEMA_NAME && Value != nullptr) {
+ const wchar_t* wstr = static_cast(Value);
+ std::string schemaName = WideStringToString(std::wstring(wstr));
+ return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(schemaName.c_str()), SQL_NTS);
+ }
+
+ if (FieldIdentifier == SQL_DESC_TABLE_NAME && Value != nullptr) {
+ const wchar_t* wstr = static_cast(Value);
+ std::string tableName = WideStringToString(std::wstring(wstr));
+ return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(tableName.c_str()), SQL_NTS);
+ }
+
+ if (FieldIdentifier == SQL_DESC_BASE_COLUMN_NAME && Value != nullptr) {
+ const wchar_t* wstr = static_cast(Value);
+ std::string baseColumnName = WideStringToString(std::wstring(wstr));
+ return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(baseColumnName.c_str()), SQL_NTS);
+ }
+
+ if (FieldIdentifier == SQL_DESC_BASE_TABLE_NAME && Value != nullptr) {
+ const wchar_t* wstr = static_cast(Value);
+ std::string baseTableName = WideStringToString(std::wstring(wstr));
+ return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(baseTableName.c_str()), SQL_NTS);
+ }
+
+ if (FieldIdentifier == SQL_DESC_LOCAL_TYPE_NAME && Value != nullptr) {
+ const wchar_t* wstr = static_cast(Value);
+ std::string localTypeName = WideStringToString(std::wstring(wstr));
+ return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(localTypeName.c_str()), SQL_NTS);
+ }
+
+ if (FieldIdentifier == SQL_DESC_LABEL && Value != nullptr) {
+ const wchar_t* wstr = static_cast(Value);
+ std::string label = WideStringToString(std::wstring(wstr));
+ return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier,
+ const_cast(label.c_str()), SQL_NTS);
+ }
+
+
return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, Value, BufferLength);
}
}
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp
index d453a01..fc37a31 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp
@@ -1,182 +1,174 @@
-// Windows includes must come first
-#define WIN32_LEAN_AND_MEAN
-#include
-#include
-
-// ODBC includes
-#include
-
-// JNI includes
-#include
+#include "../include/connection.hpp"
-// Standard includes
-#include
-#include
-#include
-#include
+#include
-// Project includes
-#include "../include/connection.hpp"
#include "../include/statement.hpp"
-#include "../include/globals.hpp"
#include "../include/logging.hpp"
+#include
+#include
-bool Connection::initJVM() {
- LOG("Initializing JVM");
+#include "JniParam.hpp"
- JavaVMInitArgs vm_args;
- JavaVMOption options[5];
- // Add the path to your JAR files
- std::string classPath = "-Djava.class.path=";
- classPath += dllPath + "\\jni-arrow-1.0.0-jar-with-dependencies.jar";
- LOGF("Java classpath: %s", classPath.c_str());
+HMODULE GetCurrentModule() {
+ HMODULE hModule = nullptr;
+ GetModuleHandleEx(
+ GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
+ GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
+ reinterpret_cast(GetCurrentModule),
+ &hModule);
+ return hModule;
+}
- options[0].optionString = const_cast(classPath.c_str());
+std::string GetModuleDirectory() {
+ char path[MAX_PATH];
+ HMODULE hModule = GetCurrentModule();
+ GetModuleFileNameA(hModule, path, MAX_PATH);
+ std::string modulePath(path);
+ size_t pos = modulePath.find_last_of("\\/");
+ return (std::string::npos == pos) ? "" : modulePath.substr(0, pos);
+}
- // Add any necessary JVM options
- std::string extraOptions = "-Xmx512m";
- options[1].optionString = const_cast(extraOptions.c_str());
- options[2].optionString = const_cast("--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED");
- options[3].optionString = const_cast("-Dotel.java.global-autoconfigure.enabled=true");
- options[4].optionString = const_cast("-Dlog_level=debug");
+Connection::~Connection() {
+ if (connected) {
+ disconnect();
+ }
+}
- // Set the JNI version to 1.8 for OpenJDK 11
- vm_args.version = JNI_VERSION_1_8;
- vm_args.nOptions = 2;
- vm_args.options = options;
- vm_args.ignoreUnrecognized = JNI_TRUE;
+// std::string WideStringToString(const std::wstring& wstr) {
+// LOG("WideStringToString: Input length: " + std::to_string(wstr.length()));
+//
+// if (wstr.empty()) {
+// LOG("WideStringToString: Empty input string");
+// return {};
+// }
+//
+// // Print first few wide chars as numbers for debugging
+// char intbuf[100] = {0};
+// char* intptr = intbuf;
+// size_t maxChars = wstr.length() > 5 ? 5 : wstr.length();
+// for (size_t i = 0; i < maxChars; ++i) {
+// intptr += sprintf_s(intptr, 20, "%d ", static_cast(wstr[i]));
+// }
+// LOG("WideStringToString: First 5 chars (as ints): " + std::string(intbuf));
+//
+// // Get required buffer size
+// const int size_needed = WideCharToMultiByte(CP_UTF8, 0,
+// wstr.data(), static_cast(wstr.length()),
+// nullptr, 0,
+// nullptr, nullptr);
+//
+// if (size_needed <= 0) {
+// const DWORD error = GetLastError();
+// LOG("WideStringToString: Error calculating buffer size. Error code: " + std::to_string(error));
+// return std::string();
+// }
+//
+// LOG("WideStringToString: Allocating buffer of size: " + std::to_string(size_needed));
+// std::string strTo(size_needed, 0);
+//
+// int result = WideCharToMultiByte(CP_UTF8, 0,
+// wstr.data(), static_cast(wstr.length()),
+// &strTo[0], size_needed,
+// nullptr, nullptr);
+//
+// if (result <= 0) {
+// DWORD error = GetLastError();
+// LOG("WideStringToString: Conversion failed. Error code: " + std::to_string(error));
+// return std::string();
+// }
+//
+// LOG("WideStringToString: Converted string length: " + std::to_string(strTo.length()));
+// LOG("WideStringToString: Converted string: '" + strTo + "'");
+//
+// return strTo;
+// }
+
+#include // For wcstombs
+#include
+#include
+
+// Function to convert std::wstring to std::string
+std::string WideStringToString(const std::wstring& wstr) {
+ // Determine the required buffer size for the multibyte string
+ size_t len = wcstombs(nullptr, wstr.c_str(), 0);
+ if (len == static_cast(-1)) {
+ throw std::runtime_error("Conversion error");
+ }
- if (const jint rc = JNI_CreateJavaVM(&jvm, reinterpret_cast(&env), &vm_args); rc != JNI_OK) {
- LOGF("Failed to create JVM: %d", rc);
- LOGF("JNI_CreateJavaVM return code: %d", rc);
- return false;
+ // Create a buffer of the appropriate size
+ std::vector buffer(len + 1);
+ wcstombs(buffer.data(), wstr.c_str(), buffer.size());
+
+ // Convert the buffer to std::string
+ return std::string(buffer.data());
+}
+
+// Function to convert std::string to std::wstring
+std::wstring StringToWideString(const std::string& str) {
+ // Determine the required buffer size for the wide character string
+ size_t len = mbstowcs(nullptr, str.c_str(), 0);
+ if (len == static_cast(-1)) {
+ throw std::runtime_error("Conversion error");
}
- LOG("JVM created successfully");
- return true;
+ // Create a buffer of the appropriate size
+ std::vector buffer(len + 1);
+ mbstowcs(buffer.data(), str.c_str(), buffer.size());
+
+ // Convert the buffer to std::wstring
+ return std::wstring(buffer.data());
}
void Connection::setConnectionString(const std::string& dsn, const std::string& uid, const std::string& authStr) {
- // Store connection string
std::ostringstream ss;
ss << "DSN=" << dsn << ";UID=" << uid << ";PWD=" << authStr;
this->connectionString = ss.str();
}
void Connection::setConnectionString(const std::string& dsn) {
- this->connectionString =dsn;
-}
-
-bool Connection::initWrapper(const ConnectionParams& params) {
- try {
- // Find the wrapper class
- wrapperClass = env->FindClass("com/hasura/ArrowJdbcWrapper");
- if (!wrapperClass) {
- LOG("Failed to find ArrowJdbcWrapper class");
- return false;
- }
- LOG("Found ArrowJdbcWrapper class");
-
- // Get the method ID for the constructor
- jmethodID constructor = env->GetMethodID(wrapperClass, "", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V");
- if (constructor == nullptr) {
- LOG("Failed to find ArrowJdbcWrapper constructor");
- return false;
- }
- LOG("Found ArrowJdbcWrapper constructor");
-
- // Get the method ID for healthCheck
- jmethodID healthCheckMethod = env->GetMethodID(wrapperClass, "healthCheck", "()Z");
- if (healthCheckMethod == nullptr) {
- LOG("Failed to find healthCheck method");
- return false;
- }
- LOG("Found healthCheck method");
-
- // Create Java strings for the constructor arguments
- const std::string jdbcUrl = buildJdbcUrl(params);
- jstring jvm_jdbcUrl = env->NewStringUTF(jdbcUrl.c_str()); // Replace JDBC_URL with your actual URL
- jstring username = env->NewStringUTF(params.uid.c_str()); // Replace USERNAME with your actual username
- jstring password = env->NewStringUTF(params.pwd.c_str()); // Replace PASSWORD with your actual password
-
- // Create a new instance of ArrowJdbcWrapper
- LOGF("jdbcUrl %s, username %s, password, %s", jdbcUrl.c_str(), params.uid.c_str(), params.pwd.c_str());
- wrapperInstance = env->NewObject(wrapperClass, constructor, jvm_jdbcUrl, username, password);
- if (wrapperInstance == nullptr) {
- LOG("Failed to construct ArrowJdbcWrapper");
- // Check if an exception was thrown during the object creation
- if (env->ExceptionCheck()) {
- env->ExceptionDescribe(); // Print the exception details to stderr
- env->ExceptionClear(); // Clear the exception so that we can continue
- LOG("Failed to construct ArrowJdbcWrapper due to an exception");
- return false;
- }
- return false;
- }
- LOG("Initialized ArrowJdbcWrapper");
-
- // Call the healthCheck method
- const jboolean healthStatus = env->CallBooleanMethod(wrapperInstance, healthCheckMethod);
- if (env->ExceptionCheck()) {
- env->ExceptionDescribe();
- env->ExceptionClear();
- LOG("Exception occurred while calling healthCheck method");
- return false;
- }
- bool isHealthy = (healthStatus == JNI_TRUE);
- LOGF("Health check status: %d", isHealthy);
-
- return isHealthy;
- }
- catch (const std::exception& e) {
- LOGF("Exception in initWrapper: %s", e.what());
- return false;
- }
+ this->connectionString = dsn;
}
SQLRETURN Connection::connect() {
- if (isConnected) {
+ if (connected) {
LOG("Already connected");
return SQL_ERROR;
}
- // Parse connection parameters
ConnectionParams params;
if (!parseConnectionString(connectionString, params)) {
LOG("Failed to parse connection string");
return SQL_ERROR;
}
- // Initialize JVM if not already initialized
if (!jvm && !initJVM()) {
LOG("Failed to initialize JVM");
return SQL_ERROR;
}
- // Initialize wrapper with connection parameters
if (!initWrapper(params)) {
LOG("Failed to initialize wrapper");
return SQL_ERROR;
}
- isConnected = true;
- LOG("Connected!");
+ connected = true;
+ LOG("Connected successfully");
return SQL_SUCCESS;
}
SQLRETURN Connection::disconnect() {
- if (!isConnected) {
+ if (!connected) {
LOG("Not connected");
return SQL_ERROR;
}
+ cleanupActiveStmts();
+
if (env && wrapperInstance) {
- // Call close on the wrapper instance
jmethodID closeMethod = env->GetMethodID(wrapperClass, "close", "()V");
env->CallVoidMethod(wrapperInstance, closeMethod);
- // Delete global references
env->DeleteGlobalRef(wrapperInstance);
env->DeleteGlobalRef(wrapperClass);
@@ -190,100 +182,169 @@ SQLRETURN Connection::disconnect() {
env = nullptr;
}
- isConnected = false;
+ connected = false;
return SQL_SUCCESS;
}
-Connection::~Connection() {
- if (isConnected) {
- disconnect();
+SQLRETURN Connection::Query(const std::string& query, Statement* stmt) {
+ stmt->setOriginalQuery(query);
+ std::string interpolatedQuery = stmt->buildInterpolatedQuery();
+ return executeAndGetArrowResult("executeQuery", {JniParam(interpolatedQuery)}, stmt);
+}
+
+SQLRETURN Connection::GetTables(
+ const std::string& catalogName,
+ const std::string& schemaName,
+ const std::string& tableName,
+ const std::string& tableType,
+ Statement* stmt) const {
+
+ std::vector types;
+ if (!tableType.empty()) {
+ std::istringstream ss(tableType);
+ std::string type;
+ while (std::getline(ss, type, ',')) {
+ type.erase(0, type.find_first_not_of(' '));
+ type.erase(type.find_last_not_of(' ') + 1);
+ if (!type.empty()) {
+ types.push_back(type);
+ }
+ }
}
+
+ return executeAndGetArrowResult("getTables", {
+ JniParam(catalogName),
+ JniParam(schemaName),
+ JniParam(tableName),
+ JniParam(types) // This will be converted to String[]
+ }, stmt);
}
-// Helper functions
-std::string GetModuleDirectory(HMODULE hModule) {
- char path[MAX_PATH];
- GetModuleFileNameA(hModule, path, MAX_PATH);
- std::string modulePath(path);
- size_t pos = modulePath.find_last_of("\\/");
- return (std::string::npos == pos) ? "" : modulePath.substr(0, pos);
+SQLRETURN Connection::GetColumns(
+ const std::string& catalogName,
+ const std::string& schemaName,
+ const std::string& tableName,
+ const std::string& columnName,
+ Statement* stmt) const {
+
+ return executeAndGetArrowResult("getColumns", {
+ JniParam(catalogName),
+ JniParam(schemaName),
+ JniParam(tableName),
+ JniParam(columnName)
+ }, stmt);
}
-std::string WideStringToString(const std::wstring& wstr) {
- LOG("WideStringToString: Input length: " + std::to_string(wstr.length()));
-
- if (wstr.empty()) {
- LOG("WideStringToString: Empty input string");
- return {};
- }
+bool Connection::hasActiveStmts() const {
+ return !activeStmts.empty();
+}
- // Print first few wide chars as numbers for debugging
- char intbuf[100] = {0};
- char* intptr = intbuf;
- size_t maxChars = wstr.length() > 5 ? 5 : wstr.length();
- for (size_t i = 0; i < maxChars; ++i) {
- intptr += sprintf_s(intptr, 20, "%d ", static_cast(wstr[i]));
- }
- LOG("WideStringToString: First 5 chars (as ints): " + std::string(intbuf));
-
- // Get required buffer size
- const int size_needed = WideCharToMultiByte(CP_UTF8, 0,
- wstr.data(), static_cast(wstr.length()),
- nullptr, 0,
- nullptr, nullptr);
-
- if (size_needed <= 0) {
- const DWORD error = GetLastError();
- LOG("WideStringToString: Error calculating buffer size. Error code: " + std::to_string(error));
- return std::string();
+void Connection::cleanupActiveStmts() {
+ for (auto stmt : activeStmts) {
+ if (stmt) {
+ stmt->clearResults();
+ }
}
-
- LOG("WideStringToString: Allocating buffer of size: " + std::to_string(size_needed));
- std::string strTo(size_needed, 0);
-
- int result = WideCharToMultiByte(CP_UTF8, 0,
- wstr.data(), static_cast(wstr.length()),
- &strTo[0], size_needed,
- nullptr, nullptr);
-
- if (result <= 0) {
- DWORD error = GetLastError();
- LOG("WideStringToString: Conversion failed. Error code: " + std::to_string(error));
- return std::string();
+ activeStmts.clear();
+}
+
+bool Connection::initJVM() {
+ LOG("Initializing JVM");
+
+ JavaVMInitArgs vm_args;
+ JavaVMOption options[5];
+
+ std::string classPath = "-Djava.class.path=";
+ classPath += GetModuleDirectory() + "\\jni-arrow-1.0.0-jar-with-dependencies.jar";
+ LOGF("Java classpath: %s", classPath.c_str());
+
+ options[0].optionString = const_cast(classPath.c_str());
+ options[1].optionString = const_cast("-Xmx512m");
+ options[2].optionString = const_cast("--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED");
+ options[3].optionString = const_cast("-Dotel.java.global-autoconfigure.enabled=true");
+ options[4].optionString = const_cast("-Dlog_level=warn");
+
+ vm_args.version = JNI_VERSION_1_8;
+ vm_args.nOptions = 5;
+ vm_args.options = options;
+ vm_args.ignoreUnrecognized = JNI_TRUE;
+
+ jint rc = JNI_CreateJavaVM(&jvm, reinterpret_cast(&env), &vm_args);
+ if (rc != JNI_OK) {
+ LOGF("Failed to create JVM: %d", rc);
+ return false;
}
-
- LOG("WideStringToString: Converted string length: " + std::to_string(strTo.length()));
- LOG("WideStringToString: Converted string: '" + strTo + "'");
-
- return strTo;
+
+ LOG("JVM created successfully");
+ return true;
}
-bool Connection::parseConnectionString(const std::string& connStr, ConnectionParams& params) {
- if (connStr.empty()) {
- LOG("Empty connection string provided");
+bool Connection::initWrapper(const ConnectionParams& params) {
+ try {
+ wrapperClass = env->FindClass("com/hasura/ArrowJdbcWrapper");
+ if (!wrapperClass) {
+ LOG("Failed to find ArrowJdbcWrapper class");
+ return false;
+ }
+
+ wrapperClass = reinterpret_cast(env->NewGlobalRef(wrapperClass));
+
+ jmethodID constructor = env->GetMethodID(wrapperClass, "",
+ "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V");
+ if (!constructor) {
+ LOG("Failed to find constructor");
+ return false;
+ }
+
+ const std::string jdbcUrl = buildJdbcUrl(params);
+ jstring jvm_jdbcUrl = env->NewStringUTF(jdbcUrl.c_str());
+ jstring username = env->NewStringUTF(params.uid.c_str());
+ jstring password = env->NewStringUTF(params.pwd.c_str());
+
+ jobject localInstance = env->NewObject(wrapperClass, constructor, jvm_jdbcUrl, username, password);
+ if (!localInstance) {
+ LOG("Failed to create wrapper instance");
+ return false;
+ }
+
+ wrapperInstance = env->NewGlobalRef(localInstance);
+ env->DeleteLocalRef(localInstance);
+
+ env->DeleteLocalRef(jvm_jdbcUrl);
+ env->DeleteLocalRef(username);
+ env->DeleteLocalRef(password);
+
+ jmethodID healthCheckMethod = env->GetMethodID(wrapperClass, "healthCheck", "()Z");
+ if (!healthCheckMethod) {
+ LOG("Failed to find healthCheck method");
+ return false;
+ }
+
+ jboolean health = env->CallBooleanMethod(wrapperInstance, healthCheckMethod);
+ return health == JNI_TRUE;
+ }
+ catch (const std::exception& e) {
+ LOGF("Exception in initWrapper: %s", e.what());
return false;
}
+}
- LOGF("Parsing connection string: %s", connStr.c_str());
+bool Connection::parseConnectionString(const std::string& connStr, ConnectionParams& params) {
std::istringstream ss(connStr);
std::string token;
while (std::getline(ss, token, ';')) {
- auto pos = token.find('=');
- if (pos == std::string::npos) {
- LOGF("Invalid token: %s", token.c_str());
- continue;
- }
+ size_t pos = token.find('=');
+ if (pos == std::string::npos) continue;
std::string key = token.substr(0, pos);
std::string value = token.substr(pos + 1);
- while (!key.empty() && key[0] == ' ') key.erase(0, 1);
- while (!key.empty() && key.back() == ' ') key.pop_back();
- while (!value.empty() && value[0] == ' ') value.erase(0, 1);
- while (!value.empty() && value.back() == ' ') value.pop_back();
-
- LOGF("Parsed key-value pair: %s=%s", key.c_str(), value.c_str());
+ // Trim whitespace
+ key.erase(0, key.find_first_not_of(' '));
+ key.erase(key.find_last_not_of(' ') + 1);
+ value.erase(0, value.find_first_not_of(' '));
+ value.erase(value.find_last_not_of(' ') + 1);
if (key == "Server") params.server = value;
else if (key == "Port") params.port = value;
@@ -294,45 +355,14 @@ bool Connection::parseConnectionString(const std::string& connStr, ConnectionPar
else if (key == "PWD" || key == "Password") params.pwd = value;
else if (key == "Encrypt") params.encrypt = value;
else if (key == "Timeout") params.timeout = value;
- else if (key == "DRIVER") {
- LOGF("Ignoring key: %s", key.c_str());
- }
- else {
- LOGF("Unknown key: %s", key.c_str());
- }
}
- LOGF("Server: %s, Port: %s, Database: %s",
- params.server.c_str(), params.port.c_str(), params.database.c_str());
-
- if (!params.isValid()) {
- LOG("Connection parameters are not valid");
- return false;
- }
-
- return true;
-}
-
-bool Connection::hasActiveStmts() const {
- // Return true if there are any active statements
- return !activeStmts.empty();
-}
-
-void Connection::cleanupActiveStmts() {
- // Clean up any active statements
- for (auto stmt : activeStmts) {
- if (stmt) {
- stmt->clearResults();
- // You might want to also free the statement here
- // depending on your memory management strategy
- }
- }
- activeStmts.clear();
+ return params.isValid();
}
std::string Connection::buildJdbcUrl(const ConnectionParams& params) {
std::string protocol = (params.encrypt == "yes") ? "https" : "http";
-
+
std::ostringstream url;
url << "jdbc:graphql:" << protocol << "://"
<< params.server << ":" << params.port << "/"
@@ -356,263 +386,221 @@ std::string Connection::buildJdbcUrl(const ConnectionParams& params) {
return url.str();
}
-SQLRETURN Connection::GetTables(
- const std::string& catalogName,
- const std::string& schemaName,
- const std::string& tableName,
- const std::string& tableType,
- Statement* stmt) const {
-
- LOGF("GetTables called with params - catalog: '%s', schema: '%s', table: '%s', types: '%s'",
- catalogName.c_str(), schemaName.c_str(), tableName.c_str(), tableType.c_str());
-
- if (!isConnected || !env || !wrapperInstance || !wrapperClass) {
- LOG("Connection validation failed:");
- LOGF(" - isConnected: %d", isConnected);
- LOGF(" - env: %p", (void*)env);
- LOGF(" - wrapperInstance: %p", (void*)wrapperInstance);
- LOGF(" - wrapperClass: %p", (void*)wrapperClass);
- return SQL_ERROR;
- }
-
+SQLRETURN Connection::populateColumnDescriptors(jobject schemaRoot, Statement* stmt) const {
try {
- // Helper lambda to convert std::string to jstring
- auto createJavaString = [this](const std::string& str) -> jstring {
- if (str.empty()) {
- LOG("Creating null jstring for empty input");
- return nullptr;
- }
- LOGF("Creating jstring for input: '%s'", str.c_str());
- return env->NewStringUTF(str.c_str());
- };
-
- // Convert input strings to Java strings
- LOG("Converting input parameters to Java strings");
- jstring jCatalog = createJavaString(catalogName);
- LOGF("Created jCatalog: %p", (void*)jCatalog);
-
- jstring jSchema = createJavaString(schemaName);
- LOGF("Created jSchema: %p", (void*)jSchema);
-
- jstring jTable = createJavaString(tableName);
- LOGF("Created jTable: %p", (void*)jTable);
-
- // Handle table types
- jobjectArray jTypes = nullptr;
- if (!tableType.empty()) {
- LOG("Processing table types string");
- std::vector typeList;
- std::istringstream typeStream(tableType);
- std::string type;
-
- // Parse comma-separated types
- while (std::getline(typeStream, type, ',')) {
- // Trim whitespace
- type.erase(0, type.find_first_not_of(' '));
- type.erase(type.find_last_not_of(' ') + 1);
- if (!type.empty()) {
- LOGF("Found table type: '%s'", type.c_str());
- typeList.push_back(type);
- }
- }
-
- if (!typeList.empty()) {
- LOGF("Creating Java string array with %zu types", typeList.size());
- jclass stringClass = env->FindClass("java/lang/String");
- if (!stringClass) {
- LOG("Failed to find java.lang.String class");
- return SQL_ERROR;
- }
-
- jTypes = env->NewObjectArray(typeList.size(), stringClass, env->NewStringUTF(""));
- if (!jTypes) {
- LOG("Failed to create Java string array for types");
- return SQL_ERROR;
- }
-
- for (size_t i = 0; i < typeList.size(); i++) {
- LOGF("Setting array element %zu to '%s'", i, typeList[i].c_str());
- jstring typeStr = env->NewStringUTF(typeList[i].c_str());
- env->SetObjectArrayElement(jTypes, i, typeStr);
- env->DeleteLocalRef(typeStr);
- }
- LOGF("Successfully created jTypes array: %p", (void*)jTypes);
- }
- } else {
- LOG("No table types specified");
- }
-
- // Find getTables method
- LOG("Looking up getTables method");
- jmethodID getTablesMethod = env->GetMethodID(wrapperClass, "getTables",
- "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;[Ljava/lang/String;)Lorg/apache/arrow/vector/VectorSchemaRoot;");
-
- if (!getTablesMethod) {
- LOG("Failed to find getTables method");
+ // Extract schema and get the number of fields
+ jclass rootClass = env->GetObjectClass(schemaRoot);
+ jmethodID getSchemaMethod = env->GetMethodID(rootClass, "getSchema",
+ "()Lorg/apache/arrow/vector/types/pojo/Schema;");
+ jobject schema = env->CallObjectMethod(schemaRoot, getSchemaMethod);
+
+ jclass schemaClass = env->GetObjectClass(schema);
+ jmethodID getFieldsMethod = env->GetMethodID(schemaClass, "getFields", "()Ljava/util/List;");
+ jobject fieldsList = env->CallObjectMethod(schema, getFieldsMethod);
+
+ jclass listClass = env->GetObjectClass(fieldsList);
+ jmethodID sizeMethod = env->GetMethodID(listClass, "size", "()I");
+ jint numFields = env->CallIntMethod(fieldsList, sizeMethod);
+ LOGF("Schema contains %d fields", numFields);
+
+ // Get the IRD handle
+ SQLHDESC hIRD = nullptr;
+ SQLRETURN ret = SQLGetStmtAttr(stmt, SQL_ATTR_IMP_ROW_DESC, &hIRD, 0, nullptr);
+ if (!SQL_SUCCEEDED(ret) || !hIRD) {
+ LOG("Failed to get IRD handle");
return SQL_ERROR;
}
- LOGF("Found getTables method: %p", (void*)getTablesMethod);
-
- // Call getTables
- LOG("Calling getTables method");
- jobject schemaRoot = env->CallObjectMethod(wrapperInstance, getTablesMethod,
- jCatalog, jSchema, jTable, jTypes);
-
- // Clean up local references
- LOG("Cleaning up local references");
- if (jCatalog) {
- env->DeleteLocalRef(jCatalog);
- LOG("Deleted jCatalog reference");
- }
- if (jSchema) {
- env->DeleteLocalRef(jSchema);
- LOG("Deleted jSchema reference");
- }
- if (jTable) {
- env->DeleteLocalRef(jTable);
- LOG("Deleted jTable reference");
- }
- if (jTypes) {
- env->DeleteLocalRef(jTypes);
- LOG("Deleted jTypes reference");
+
+ // Set up descriptors for each column
+ stmt->resultColumns.resize(numFields);
+ for (jint i = 0; i < numFields; i++) {
+ jobject field = env->CallObjectMethod(fieldsList, env->GetMethodID(listClass, "get", "(I)Ljava/lang/Object;"), i);
+ jclass fieldClass = env->GetObjectClass(field);
+
+ // Get field name
+ jmethodID getNameMethod = env->GetMethodID(fieldClass, "getName", "()Ljava/lang/String;");
+ jstring fieldName = (jstring)env->CallObjectMethod(field, getNameMethod);
+ const char* nameChars = env->GetStringUTFChars(fieldName, nullptr);
+
+ // Get field type
+ jmethodID getTypeMethod = env->GetMethodID(fieldClass, "getType",
+ "()Lorg/apache/arrow/vector/types/pojo/ArrowType;");
+ jobject arrowType = env->CallObjectMethod(field, getTypeMethod);
+
+ SQLSMALLINT sqlType = mapArrowTypeToSQL(env, arrowType);
+ SQLULEN columnSize = getSQLTypeSize(sqlType);
+
+ stmt->resultColumns[i].name = _strdup(nameChars);
+ stmt->resultColumns[i].nameLength = (SQLSMALLINT)strlen(nameChars);
+ stmt->resultColumns[i].nullable = SQL_NULLABLE;
+ stmt->resultColumns[i].columnSize = columnSize;
+ stmt->resultColumns[i].sqlType = sqlType;
+
+ env->ReleaseStringUTFChars(fieldName, nameChars);
+ env->DeleteLocalRef(fieldName);
+ env->DeleteLocalRef(arrowType);
+ env->DeleteLocalRef(field);
+ env->DeleteLocalRef(fieldClass);
}
+ env->DeleteLocalRef(listClass);
+ env->DeleteLocalRef(fieldsList);
+ env->DeleteLocalRef(schemaClass);
+ env->DeleteLocalRef(schema);
+ env->DeleteLocalRef(rootClass);
+
+ return SQL_SUCCESS;
+ } catch (const std::exception& e) {
+ LOGF("Exception in populateColumnDescriptors: %s", e.what());
if (env->ExceptionCheck()) {
- LOG("Java exception detected during getTables call");
env->ExceptionDescribe();
env->ExceptionClear();
- LOG("Exception cleared");
- return SQL_ERROR;
}
-
- if (!schemaRoot) {
- LOG("getTables returned null schemaRoot");
- return SQL_ERROR;
- }
- LOGF("Successfully got schemaRoot: %p", (void*)schemaRoot);
-
- // Set up the ODBC result set columns in the statement
- LOG("Setting up table result columns in Statement");
- auto tableDefs = stmt->setupTableResultColumns();
- LOG("Successfully set up table result columns");
-
- // Convert Arrow VectorSchemaRoot to ODBC result set
- LOG("Converting Arrow VectorSchemaRoot to ODBC result set");
- SQLRETURN result = stmt->setArrowResult(schemaRoot, tableDefs);
- LOGF("setArrowResult returned: %d", result);
-
- // Clean up schema root
- env->DeleteLocalRef(schemaRoot);
- LOG("Cleaned up schemaRoot reference");
-
- if (result == SQL_SUCCESS) {
- LOG("GetTables completed successfully");
- } else {
- LOG("GetTables completed with errors");
- }
- return result;
-
- } catch (const std::exception& e) {
- LOGF("Exception in GetTables: %s", e.what());
- LOG("Stack trace (if available):");
- LOG(e.what());
return SQL_ERROR;
} catch (...) {
- LOG("Unknown exception in GetTables");
+ LOG("Unknown exception in populateColumnDescriptors");
+ if (env->ExceptionCheck()) {
+ env->ExceptionDescribe();
+ env->ExceptionClear();
+ }
return SQL_ERROR;
}
}
-SQLRETURN Connection::GetColumns(
- const std::string& catalogName,
- const std::string& schemaName,
- const std::string& tableName,
- const std::string& columnName,
+SQLRETURN Connection::executeAndGetArrowResult(
+ const char* methodName,
+ const std::vector& params,
Statement* stmt) const {
- LOGF("GetColumns called with params - catalog: '%s', schema: '%s', table: '%s', column: '%s'",
- catalogName.c_str(), schemaName.c_str(), tableName.c_str(), columnName.c_str());
-
- if (!isConnected || !env || !wrapperInstance || !wrapperClass) {
- LOG("Connection validation failed:");
- LOGF(" - isConnected: %d", isConnected);
- LOGF(" - env: %p", (void*)env);
- LOGF(" - wrapperInstance: %p", (void*)wrapperInstance);
- LOGF(" - wrapperClass: %p", (void*)wrapperClass);
+ if (!connected || !env || !wrapperInstance || !wrapperClass) {
+ LOG("Connection not properly initialized");
return SQL_ERROR;
}
try {
- // Helper lambda to convert std::string to jstring
- auto createJavaString = [this](const std::string& str) -> jstring {
- if (str.empty()) {
- LOG("Creating null jstring for empty input");
- return nullptr;
- }
- LOGF("Creating jstring for input: '%s'", str.c_str());
- return env->NewStringUTF(str.c_str());
- };
-
- // Convert input strings to Java strings
- LOG("Converting input parameters to Java strings");
- jstring jCatalog = createJavaString(catalogName);
- jstring jSchema = createJavaString(schemaName);
- jstring jTable = createJavaString(tableName);
- jstring jColumn = createJavaString(columnName);
-
- // Find getColumns method
- LOG("Looking up getColumns method");
- jmethodID getColumnsMethod = env->GetMethodID(wrapperClass, "getColumns",
- "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Lorg/apache/arrow/vector/VectorSchemaRoot;");
-
- if (!getColumnsMethod) {
- LOG("Failed to find getColumns method");
+ // Build method signature
+ std::string signature = "(";
+ for (const auto& param : params) {
+ signature += param.getSignature();
+ }
+ signature += ")Lorg/apache/arrow/vector/VectorSchemaRoot;";
+
+ LOGF("Looking for method %s with signature %s", methodName, signature.c_str());
+
+ jmethodID method = env->GetMethodID(wrapperClass, methodName, signature.c_str());
+ if (!method) {
+ LOGF("Failed to find method: %s with signature: %s", methodName, signature.c_str());
return SQL_ERROR;
}
- // Call getColumns
- LOG("Calling getColumns method");
- jobject schemaRoot = env->CallObjectMethod(wrapperInstance, getColumnsMethod,
- jCatalog, jSchema, jTable, jColumn);
+ // Convert parameters to JNI values
+ std::vector jniValues;
+ for (const auto& param : params) {
+ jniValues.push_back(param.toJValue(env));
+ }
- // Clean up local references
- LOG("Cleaning up local references");
- if (jCatalog) env->DeleteLocalRef(jCatalog);
- if (jSchema) env->DeleteLocalRef(jSchema);
- if (jTable) env->DeleteLocalRef(jTable);
- if (jColumn) env->DeleteLocalRef(jColumn);
+ // Call the method
+ jobject schemaRoot;
+ if (params.empty()) {
+ schemaRoot = env->CallObjectMethod(wrapperInstance, method);
+ } else {
+ schemaRoot = env->CallObjectMethodA(wrapperInstance, method, jniValues.data());
+ }
+
+ // Clean up parameters
+ for (size_t i = 0; i < params.size(); i++) {
+ params[i].cleanup(env, jniValues[i]);
+ }
if (env->ExceptionCheck()) {
- LOG("Java exception detected during getColumns call");
env->ExceptionDescribe();
env->ExceptionClear();
return SQL_ERROR;
}
if (!schemaRoot) {
- LOG("getColumns returned null schemaRoot");
+ LOG("Method returned null result");
return SQL_ERROR;
}
- // Set up the ODBC result set columns in the statement
- LOG("Setting up column result columns in Statement");
- auto columnDefs = stmt->setupColumnResultColumns();
-
- // Convert Arrow VectorSchemaRoot to ODBC result set
- LOG("Converting Arrow VectorSchemaRoot to ODBC result set");
- SQLRETURN result = stmt->setArrowResult(schemaRoot, columnDefs);
+ // Extract schema and create column descriptors
+ SQLRETURN ret = populateColumnDescriptors(schemaRoot, stmt);
+ if (!SQL_SUCCEEDED(ret)) {
+ return ret;
+ }
- // Clean up schema root
+ // Process the Arrow data
+ ret = stmt->setArrowResult(schemaRoot, stmt->resultColumns);
+ jclass schemaRootClass = env->GetObjectClass(schemaRoot);
+ jmethodID closeMethod = env->GetMethodID(schemaRootClass, "close", "()V");
+ if (closeMethod != nullptr) {
+ env->CallVoidMethod(schemaRoot, closeMethod);
+ }
env->DeleteLocalRef(schemaRoot);
-
- return result;
-
+ return ret;
} catch (const std::exception& e) {
- LOGF("Exception in GetColumns: %s", e.what());
- return SQL_ERROR;
- } catch (...) {
- LOG("Unknown exception in GetColumns");
+ LOGF("Exception in %s: %s", methodName, e.what());
+ if (env->ExceptionCheck()) {
+ env->ExceptionDescribe();
+ env->ExceptionClear();
+ }
return SQL_ERROR;
}
}
+SQLSMALLINT Connection::mapArrowTypeToSQL(JNIEnv* env, jobject arrowType) {
+ jclass typeClass = env->GetObjectClass(arrowType);
+ jmethodID getTypeIDMethod = env->GetMethodID(typeClass, "getTypeID",
+ "()Lorg/apache/arrow/vector/types/pojo/ArrowType$ArrowTypeID;");
+ jobject typeId = env->CallObjectMethod(arrowType, getTypeIDMethod);
+
+ jclass enumClass = env->GetObjectClass(typeId);
+ jmethodID nameMethod = env->GetMethodID(enumClass, "name", "()Ljava/lang/String;");
+ auto typeName = reinterpret_cast(env->CallObjectMethod(typeId, nameMethod));
+
+ const char* typeNameStr = env->GetStringUTFChars(typeName, nullptr);
+ SQLSMALLINT sqlType;
+
+ // Map Arrow types to SQL types
+ if (strcmp(typeNameStr, "Int") == 0) sqlType = SQL_INTEGER;
+ else if (strcmp(typeNameStr, "FloatingPoint") == 0) sqlType = SQL_DOUBLE;
+ // else if (strcmp(typeNameStr, "Utf8") == 0) sqlType = SQL_VARCHAR;
+ else if (strcmp(typeNameStr, "Bool") == 0) sqlType = SQL_BIT;
+ else if (strcmp(typeNameStr, "Date") == 0) sqlType = SQL_TYPE_DATE;
+ else if (strcmp(typeNameStr, "Time") == 0) sqlType = SQL_TYPE_TIME;
+ else if (strcmp(typeNameStr, "Timestamp") == 0) sqlType = SQL_TYPE_TIMESTAMP;
+ else if (strcmp(typeNameStr, "Decimal") == 0) sqlType = SQL_DECIMAL;
+ else if (strcmp(typeNameStr, "Binary") == 0) sqlType = SQL_BINARY;
+ else sqlType = SQL_VARCHAR; // Default fallback
+
+ env->ReleaseStringUTFChars(typeName, typeNameStr);
+ env->DeleteLocalRef(typeName);
+ env->DeleteLocalRef(typeId);
+ env->DeleteLocalRef(enumClass);
+ env->DeleteLocalRef(typeClass);
+
+ return sqlType;
+}
+
+SQLULEN Connection::getSQLTypeSize(SQLSMALLINT sqlType) {
+ switch (sqlType) {
+ case SQL_INTEGER: return sizeof(SQLINTEGER);
+ case SQL_SMALLINT: return sizeof(SQLSMALLINT);
+ case SQL_BIGINT: return sizeof(SQLBIGINT);
+ case SQL_DOUBLE: return sizeof(SQLDOUBLE);
+ case SQL_REAL: return sizeof(SQLREAL);
+ case SQL_DECIMAL: return 38; // Max precision
+ case SQL_BIT: return 1;
+ case SQL_TINYINT: return sizeof(SQLSCHAR);
+ case SQL_TYPE_DATE: return SQL_DATE_LEN;
+ case SQL_TYPE_TIME: return SQL_TIME_LEN;
+ case SQL_TYPE_TIMESTAMP: return SQL_TIMESTAMP_LEN;
+ case SQL_BINARY:
+ case SQL_VARBINARY: return 8000; // Max binary length
+ case SQL_VARCHAR:
+ case SQL_CHAR: return 8000; // Max string length
+ case SQL_WVARCHAR:
+ case SQL_WCHAR: return 4000; // Max Unicode string length (in characters)
+ default: return 8000; // Default to max string length
+ }
+}
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp
index a6dcfa0..9151650 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp
@@ -16,7 +16,7 @@ BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserv
{
switch (ul_reason_for_call) {
case DLL_PROCESS_ATTACH:
- dllPath = GetModuleDirectory(hModule);
+ dllPath = GetModuleDirectory();
break;
case DLL_THREAD_ATTACH:
case DLL_THREAD_DETACH:
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp
index 81d0f20..5b2bfb6 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp
@@ -1,5 +1,7 @@
// First, include Windows headers
+#include
#include
+#define WIN32_LEAN_AND_MEAN
// Then SQL headers in this specific order
#include
@@ -13,38 +15,12 @@
// Your project headers
#include "../include/statement.hpp"
+#include
+#include
+
#include "connection.hpp"
#include "../include/logging.hpp"
-static const std::vector TABLE_COLUMNS = {
- {"TABLE_CAT", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE},
- {"TABLE_SCHEM", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE},
- {"TABLE_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS},
- {"TABLE_TYPE", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS},
- {"REMARKS", 0, SQL_VARCHAR, 254, 0, SQL_NULLABLE}
-};
-
-static const std::vector COLUMN_COLUMNS = {
- {"TABLE_CAT", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE},
- {"TABLE_SCHEM", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE},
- {"TABLE_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS},
- {"COLUMN_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS},
- {"DATA_TYPE", 0, SQL_SMALLINT, 5, 0, SQL_NO_NULLS},
- {"TYPE_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS},
- {"COLUMN_SIZE", 0, SQL_INTEGER, 10, 0, SQL_NULLABLE},
- {"BUFFER_LENGTH", 0, SQL_INTEGER, 10, 0, SQL_NULLABLE},
- {"DECIMAL_DIGITS", 0, SQL_SMALLINT, 5, 0, SQL_NULLABLE},
- {"NUM_PREC_RADIX", 0, SQL_SMALLINT, 5, 0, SQL_NULLABLE},
- {"NULLABLE", 0, SQL_SMALLINT, 5, 0, SQL_NO_NULLS},
- {"REMARKS", 0, SQL_VARCHAR, 254, 0, SQL_NULLABLE},
- {"COLUMN_DEF", 0, SQL_VARCHAR, 254, 0, SQL_NULLABLE},
- {"SQL_DATA_TYPE", 0, SQL_SMALLINT, 5, 0, SQL_NO_NULLS},
- {"SQL_DATETIME_SUB", 0, SQL_SMALLINT, 5, 0, SQL_NULLABLE},
- {"CHAR_OCTET_LENGTH", 0, SQL_INTEGER, 10, 0, SQL_NULLABLE},
- {"ORDINAL_POSITION", 0, SQL_INTEGER, 10, 0, SQL_NO_NULLS},
- {"IS_NULLABLE", 0, SQL_VARCHAR, 3, 0, SQL_NO_NULLS}
-};
-
SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vector& columnDescriptors) {
LOGF("Starting setArrowResult with %zu columns", columnDescriptors.size());
@@ -55,21 +31,58 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->GetObjectClass(schemaRoot);
+ if (!rootClass) {
+ LOG("Failed to get root class");
+ return SQL_ERROR;
+ }
jmethodID getRowCountMethod = conn->env->GetMethodID(rootClass, "getRowCount", "()I");
+ if (!getRowCountMethod) {
+ LOG("Failed to get getRowCount method");
+ conn->env->DeleteLocalRef(rootClass);
+ return SQL_ERROR;
+ }
jmethodID getFieldVectorsMethod = conn->env->GetMethodID(rootClass, "getFieldVectors", "()Ljava/util/List;");
+ if (!getFieldVectorsMethod) {
+ LOG("Failed to get getFieldVectors method");
+ conn->env->DeleteLocalRef(rootClass);
+ return SQL_ERROR;
+ }
jint rowCount = conn->env->CallIntMethod(schemaRoot, getRowCountMethod);
LOGF("Row count: %d", rowCount);
jobject vectorsList = conn->env->CallObjectMethod(schemaRoot, getFieldVectorsMethod);
+ if (!vectorsList) {
+ LOG("Failed to get vector list");
+ conn->env->DeleteLocalRef(rootClass);
+ return SQL_ERROR;
+ }
jclass listClass = conn->env->GetObjectClass(vectorsList);
+ if (!listClass) {
+ LOG("Failed to get list class");
+ conn->env->DeleteLocalRef(vectorsList);
+ conn->env->DeleteLocalRef(rootClass);
+ return SQL_ERROR;
+ }
jmethodID getMethod = conn->env->GetMethodID(listClass, "get", "(I)Ljava/lang/Object;");
+ if (!getMethod) {
+ LOG("Failed to get get method");
+ conn->env->DeleteLocalRef(listClass);
+ conn->env->DeleteLocalRef(vectorsList);
+ conn->env->DeleteLocalRef(rootClass);
+ return SQL_ERROR;
+ }
jmethodID sizeMethod = conn->env->GetMethodID(listClass, "size", "()I");
+ if (!sizeMethod) {
+ LOG("Failed to get size method");
+ conn->env->DeleteLocalRef(listClass);
+ conn->env->DeleteLocalRef(vectorsList);
+ conn->env->DeleteLocalRef(rootClass);
+ return SQL_ERROR;
+ }
jint vectorCount = conn->env->CallIntMethod(vectorsList, sizeMethod);
LOGF("Vector count: %d", vectorCount);
@@ -90,14 +103,32 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->GetObjectClass(fieldVector);
+ if (!vectorClass) {
+ LOGF("Failed to get vector class for column %d", col);
+ conn->env->DeleteLocalRef(fieldVector);
+ continue;
+ }
jmethodID getObjectMethod = conn->env->GetMethodID(vectorClass, "getObject", "(I)Ljava/lang/Object;");
+ if (!getObjectMethod) {
+ LOGF("Failed to get getObject method for column %d", col);
+ conn->env->DeleteLocalRef(vectorClass);
+ conn->env->DeleteLocalRef(fieldVector);
+ continue;
+ }
jmethodID isNullMethod = conn->env->GetMethodID(vectorClass, "isNull", "(I)Z");
+ if (!isNullMethod) {
+ LOGF("Failed to get isNull method for column %d", col);
+ conn->env->DeleteLocalRef(vectorClass);
+ conn->env->DeleteLocalRef(fieldVector);
+ continue;
+ }
for (jint row = 0; row < rowCount; row++) {
// Check if the value is null
jboolean isNull = conn->env->CallBooleanMethod(fieldVector, isNullMethod, row);
if (isNull || conn->env->ExceptionCheck()) {
if (conn->env->ExceptionCheck()) {
+ conn->env->ExceptionDescribe();
conn->env->ExceptionClear();
}
LOGF("Null value at row %d, col %d", row, col);
@@ -117,12 +148,21 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->FindClass("java/lang/String");
+ if (!stringClass) {
+ LOGF("Failed to find String class at row %d, col %d", row, col);
+ conn->env->DeleteLocalRef(value);
+ continue;
+ }
jmethodID toStringMethod = conn->env->GetMethodID(stringClass, "toString", "()Ljava/lang/String;");
- jstring strValue = (jstring)conn->env->CallObjectMethod(value, toStringMethod);
+ if (!toStringMethod) {
+ LOGF("Failed to get toString method at row %d, col %d", row, col);
+ conn->env->DeleteLocalRef(stringClass);
+ conn->env->DeleteLocalRef(value);
+ continue;
+ }
- if (strValue) {
- const char* chars = conn->env->GetStringUTFChars(strValue, nullptr);
- if (chars) {
+ if (auto strValue = reinterpret_cast(conn->env->CallObjectMethod(value, toStringMethod))) {
+ if (const char* chars = conn->env->GetStringUTFChars(strValue, nullptr)) {
resultData[row][col].isNull = false;
resultData[row][col].data = chars;
conn->env->ReleaseStringUTFChars(strValue, chars);
@@ -154,15 +194,24 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->ExceptionCheck()) {
conn->env->ExceptionDescribe();
conn->env->ExceptionClear();
}
clearResults();
return SQL_ERROR;
+ } catch (...) {
+ LOG("Unknown exception in setArrowResult");
+ if (conn != nullptr && conn->env != nullptr) {
+ if (conn->env->ExceptionCheck()) {
+ conn->env->ExceptionDescribe();
+ conn->env->ExceptionClear();
+ }
+ }
+ clearResults();
+ return SQL_ERROR;
}
}
@@ -197,57 +246,18 @@ bool Statement::hasData() const {
return hasResult && currentRow < resultData.size();
}
-std::vector Statement::setupColumnResultColumns() {
- // Clear any existing result set
- clearResults();
-
- // Create column descriptors for the COLUMNS result set
- resultColumns.clear();
- for (const auto& colDef : COLUMN_COLUMNS) {
- ColumnDesc col{};
- col.name = colDef.name;
- col.nameLength = static_cast(strlen(colDef.name));
- col.sqlType = colDef.sqlType;
- col.columnSize = colDef.columnSize;
- col.decimalDigits = colDef.decimalDigits;
- col.nullable = colDef.nullable;
- resultColumns.push_back(col);
- }
-
- LOGF("Set up %zu column result columns", resultColumns.size());
- return COLUMN_COLUMNS;
-}
-
-Statement::Statement(Connection* connection) : conn(connection) {
+Statement::Statement(Connection *connection) : conn(connection) {
hasResult = false;
currentRow = 0;
resultData.clear();
}
-std::vector Statement::setupTableResultColumns() {
- clearResults();
- resultColumns.clear();
-
- for (const auto& colDef : TABLE_COLUMNS) {
- ColumnDesc col{};
- col.name = colDef.name;
- col.nameLength = static_cast(strlen(colDef.name));
- col.sqlType = colDef.sqlType;
- col.columnSize = colDef.columnSize;
- col.decimalDigits = colDef.decimalDigits;
- col.nullable = colDef.nullable;
- LOGF("Setting up column %s with SQL type %d", col.name, col.sqlType);
- resultColumns.push_back(col);
- }
- return TABLE_COLUMNS;
-}
-
// Additional helper method for fetching data from the result set
SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType,
- SQLPOINTER targetValue, SQLLEN bufferLength,
- SQLLEN* strLengthOrIndicator) {
+ SQLPOINTER targetValue, SQLLEN bufferLength,
+ SQLLEN *strLengthOrIndicator) {
LOGF("getData called for column %d", colNum);
-
+
// Validate state and parameters
if (!hasResult || currentRow == 0 || currentRow > resultData.size() ||
colNum == 0 || colNum > resultColumns.size()) {
@@ -255,7 +265,7 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType,
return SQL_ERROR;
}
- const auto& colData = resultData[currentRow - 1][colNum - 1];
+ const auto &colData = resultData[currentRow - 1][colNum - 1];
LOGF("Fetching data for row %d, column %d", currentRow - 1, colNum - 1);
// Handle NULL values
@@ -271,11 +281,11 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType,
switch (targetType) {
case SQL_C_WCHAR: {
LOGF("Converting to WCHAR: '%s'", colData.data.c_str());
-
+
const int requiredSize = MultiByteToWideChar(
- CP_UTF8, 0, colData.data.c_str(), -1, nullptr, 0
- ) * sizeof(WCHAR);
-
+ CP_UTF8, 0, colData.data.c_str(), -1, nullptr, 0
+ ) * sizeof(WCHAR);
+
if (strLengthOrIndicator) {
*strLengthOrIndicator = requiredSize - sizeof(WCHAR);
}
@@ -301,7 +311,7 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType,
// Check for truncation
if (charsWritten == maxChars) {
- static_cast(targetValue)[maxChars - 1] = L'\0';
+ static_cast(targetValue)[maxChars - 1] = L'\0';
return SQL_SUCCESS_WITH_INFO;
}
@@ -315,8 +325,200 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType,
}
void Statement::clearResults() {
+ LOG("Called clearResults()");
hasResult = false;
currentRow = 0;
resultData.clear();
- resultColumns.clear();
+ boundParams.clear();
+}
+
+SQLRETURN Statement::bindParameter(SQLUSMALLINT parameterNumber,
+ SQLSMALLINT inputOutputType,
+ SQLSMALLINT valueType,
+ SQLSMALLINT parameterType,
+ SQLULEN columnSize,
+ SQLSMALLINT decimalDigits,
+ SQLPOINTER parameterValuePtr,
+ SQLLEN bufferLength,
+ SQLLEN *strLen_or_IndPtr) {
+ LOGF("Binding parameter %d of type %d", parameterNumber, valueType);
+
+ if (!conn || !conn->env) {
+ LOG("Invalid connection or environment");
+ return SQL_ERROR;
+ }
+
+ // Parameter numbers are 1-based
+ if (parameterNumber < 1) {
+ LOG("Invalid parameter number");
+ return SQL_ERROR;
+ }
+
+ // Check for null indicator
+ if (strLen_or_IndPtr && *strLen_or_IndPtr == SQL_NULL_DATA) {
+ // Handle NULL parameter - could use a special JniParam constructor for NULL
+ if (parameterNumber > boundParams.size()) {
+ boundParams.resize(parameterNumber);
+ }
+ // You might want to add a setNull method to JniParam
+ return SQL_SUCCESS;
+ }
+
+ try {
+ // Convert ODBC parameter to JniParam based on valueType
+ switch (valueType) {
+ case SQL_C_CHAR: {
+ if (!parameterValuePtr) return SQL_ERROR;
+ std::string value(static_cast(parameterValuePtr));
+ if (parameterNumber > boundParams.size()) {
+ boundParams.resize(parameterNumber);
+ }
+ boundParams[parameterNumber - 1] = JniParam(value);
+ break;
+ }
+
+ case SQL_C_WCHAR: {
+ if (!parameterValuePtr) return SQL_ERROR;
+ wchar_t* wstr = static_cast(parameterValuePtr);
+ // Convert wide string to UTF-8
+ int requiredSize = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, nullptr, 0, nullptr, nullptr);
+ if (requiredSize == 0) return SQL_ERROR;
+
+ std::string utf8str(requiredSize, '\0');
+ if (WideCharToMultiByte(CP_UTF8, 0, wstr, -1, &utf8str[0], requiredSize, nullptr, nullptr) == 0) {
+ return SQL_ERROR;
+ }
+ utf8str.resize(strlen(utf8str.c_str())); // Remove trailing null
+
+ if (parameterNumber > boundParams.size()) {
+ boundParams.resize(parameterNumber);
+ }
+ boundParams[parameterNumber - 1] = JniParam(utf8str);
+ break;
+ }
+
+ case SQL_C_LONG:
+ case SQL_C_SLONG: {
+ if (!parameterValuePtr) return SQL_ERROR;
+ int value = *static_cast(parameterValuePtr);
+ if (parameterNumber > boundParams.size()) {
+ boundParams.resize(parameterNumber);
+ }
+ boundParams[parameterNumber - 1] = JniParam(value);
+ break;
+ }
+
+ case SQL_C_FLOAT: {
+ if (!parameterValuePtr) return SQL_ERROR;
+ float value = *static_cast(parameterValuePtr);
+ if (parameterNumber > boundParams.size()) {
+ boundParams.resize(parameterNumber);
+ }
+ boundParams[parameterNumber - 1] = JniParam(value);
+ break;
+ }
+
+ case SQL_C_DOUBLE: {
+ if (!parameterValuePtr) return SQL_ERROR;
+ double value = *static_cast(parameterValuePtr);
+ if (parameterNumber > boundParams.size()) {
+ boundParams.resize(parameterNumber);
+ }
+ boundParams[parameterNumber - 1] = JniParam(value);
+ break;
+ }
+
+ case SQL_C_BIT: {
+ if (!parameterValuePtr) return SQL_ERROR;
+ bool value = (*static_cast(parameterValuePtr)) != 0;
+ if (parameterNumber > boundParams.size()) {
+ boundParams.resize(parameterNumber);
+ }
+ boundParams[parameterNumber - 1] = JniParam(value);
+ break;
+ }
+
+ default:
+ LOGF("Unsupported parameter type: %d", valueType);
+ return SQL_ERROR;
+ }
+
+ return SQL_SUCCESS;
+ }
+ catch (const std::exception& e) {
+ LOGF("Exception in bindParameter: %s", e.what());
+ return SQL_ERROR;
+ }
}
+
+std::string Statement::escapeString(const std::string& str) const {
+ std::string escaped;
+ escaped.reserve(str.length() + str.length()/8); // Reserve extra space for escapes
+
+ for (char c : str) {
+ switch (c) {
+ case '\'': escaped += "''"; break; // Double single quotes for SQL
+ default: escaped += c; break;
+ }
+ }
+ return escaped;
+}
+
+std::string Statement::buildInterpolatedQuery() const {
+ std::string result = originalQuery;
+
+ // Find all ? parameters and replace them
+ size_t paramIndex = 0;
+ size_t pos = 0;
+
+ while ((pos = result.find('?', pos)) != std::string::npos) {
+ if (paramIndex >= boundParams.size()) {
+ throw std::runtime_error("Not enough parameters bound for query");
+ }
+
+ const auto& param = boundParams[paramIndex];
+ std::string replacement;
+
+ // Convert parameter to string representation
+ switch (param.getType()) {
+ case JniParam::Type::String:
+ replacement = "'" + escapeString(param.getString()) + "'";
+ break;
+
+ case JniParam::Type::StringArray: {
+ replacement = "(";
+ bool first = true;
+ for (const auto& str : param.getStringArray()) {
+ if (!first) replacement += ",";
+ replacement += "'" + escapeString(str) + "'";
+ first = false;
+ }
+ replacement += ")";
+ break;
+ }
+
+ case JniParam::Type::Integer:
+ replacement = std::to_string(param.getInt());
+ break;
+
+ case JniParam::Type::Float:
+ replacement = std::to_string(param.getFloat());
+ break;
+
+ case JniParam::Type::Double:
+ replacement = std::to_string(param.getDouble());
+ break;
+
+ case JniParam::Type::Boolean:
+ replacement = param.getBool() ? "1" : "0";
+ break;
+ }
+
+ result.replace(pos, 1, replacement);
+ pos += replacement.length();
+ paramIndex++;
+ }
+
+ LOGF("Interpolated query: %s", result.c_str());
+ return result;
+}
\ No newline at end of file
diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs b/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs
index f27d29d..3e1e37d 100755
--- a/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs
+++ b/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs
@@ -129,7 +129,26 @@ static void TestBasicQueries(OdbcConnection conn)
if (tables.Rows.Count > 0)
{
string tableName = tables.Rows[0]["TABLE_NAME"].ToString();
- TestMetadata(conn, "Basic Select", $"SELECT * FROM {tableName} LIMIT 5");
+ using var cmd = conn.CreateCommand();
+ cmd.CommandText = $"SELECT * FROM \"{tableName}\" LIMIT 5";
+ try
+ {
+ using var reader = cmd.ExecuteReader();
+ var columns = Enumerable.Range(0, reader.FieldCount)
+ .Select(i => reader.GetName(i));
+ Console.WriteLine(string.Join("\t", columns));
+
+ while (reader.Read())
+ {
+ var values = Enumerable.Range(0, reader.FieldCount)
+ .Select(i => reader[i]?.ToString() ?? "NULL");
+ Console.WriteLine(string.Join("\t", values));
+ }
+ }
+ catch (Exception ex)
+ {
+ Console.WriteLine($"Parameterized query not supported: {ex.Message}");
+ }
}
}
@@ -177,14 +196,16 @@ static void TestMetadataRetrieval(OdbcConnection conn)
cmd.CommandText = $"SELECT * FROM {tableName} LIMIT 1";
using var reader = cmd.ExecuteReader();
- for (int i = 0; i < reader.FieldCount; i++)
- {
- Console.WriteLine($"Column {i}:");
- Console.WriteLine($" Name: {reader.GetName(i)}");
- Console.WriteLine($" Type: {reader.GetFieldType(i)}");
- Console.WriteLine($" Precision: {reader.GetSchemaTable().Rows[i]["NumericPrecision"]}");
- Console.WriteLine($" Scale: {reader.GetSchemaTable().Rows[i]["NumericScale"]}");
- Console.WriteLine($" IsNullable: {reader.GetSchemaTable().Rows[i]["AllowDBNull"]}");
+ while(reader.Read()) {
+ for (int i = 0; i < reader.FieldCount; i++)
+ {
+ Console.WriteLine($"Column {i}:");
+ Console.WriteLine($" Name: {reader.GetName(i)}");
+ Console.WriteLine($" Type: {reader.GetFieldType(i)}");
+ Console.WriteLine($" Precision: {reader.GetSchemaTable().Rows[i]["NumericPrecision"]}");
+ Console.WriteLine($" Scale: {reader.GetSchemaTable().Rows[i]["NumericScale"]}");
+ Console.WriteLine($" IsNullable: {reader.GetSchemaTable().Rows[i]["AllowDBNull"]}");
+ }
}
}
}
diff --git a/calcite-rs-jni/odbc/build-log.txt b/calcite-rs-jni/odbc/build-log.txt
index 3f5378a..fc52de3 100755
Binary files a/calcite-rs-jni/odbc/build-log.txt and b/calcite-rs-jni/odbc/build-log.txt differ