diff --git a/DBChatPro.UI/Components/Pages/ConnectDb.razor b/DBChatPro.UI/Components/Pages/ConnectDb.razor index ef4a32f..2c80978 100644 --- a/DBChatPro.UI/Components/Pages/ConnectDb.razor +++ b/DBChatPro.UI/Components/Pages/ConnectDb.razor @@ -32,6 +32,7 @@ MYSQL POSTGRESQL ORACLE + SNOWFLAKE diff --git a/DBChatPro.UI/Program.cs b/DBChatPro.UI/Program.cs index 67eebd9..ea894b2 100644 --- a/DBChatPro.UI/Program.cs +++ b/DBChatPro.UI/Program.cs @@ -24,6 +24,7 @@ builder.Services.AddScoped(); builder.Services.AddScoped(); builder.Services.AddScoped(); +builder.Services.AddScoped(); if (!string.IsNullOrEmpty(builder.Configuration["AWS:Profile"])) { diff --git a/DBChatPro.UI/SampleConnectionStrings.md b/DBChatPro.UI/SampleConnectionStrings.md index 679d78b..98346b3 100644 --- a/DBChatPro.UI/SampleConnectionStrings.md +++ b/DBChatPro.UI/SampleConnectionStrings.md @@ -9,3 +9,6 @@ Host=127.0.0.1;Port=5432;Database=sakila;Username=sakila;Password=p_ssW0rd; User Id=;Password=;Data Source=:/; + + +account=;user=;password=;db=;schema=;warehouse=; diff --git a/DbChatPro.Core/DbChatPro.Core.csproj b/DbChatPro.Core/DbChatPro.Core.csproj index d4bedc8..4789d17 100644 --- a/DbChatPro.Core/DbChatPro.Core.csproj +++ b/DbChatPro.Core/DbChatPro.Core.csproj @@ -29,5 +29,6 @@ + diff --git a/DbChatPro.Core/Services/DatabaseManagerService.cs b/DbChatPro.Core/Services/DatabaseManagerService.cs index 48770ba..591d21f 100644 --- a/DbChatPro.Core/Services/DatabaseManagerService.cs +++ b/DbChatPro.Core/Services/DatabaseManagerService.cs @@ -11,7 +11,8 @@ public class DatabaseManagerService( MySqlDatabaseService mySqlDb, SqlServerDatabaseService msSqlDb, PostgresDatabaseService postgresDb, - OracleDatabaseService oracleDb) : IDatabaseService + OracleDatabaseService oracleDb, + SnowflakeDatabaseService snowflakeDb) : IDatabaseService { public async Task>> GetDataTable(AIConnection conn, string sqlQuery) { @@ -25,6 +26,8 @@ public async Task>> GetDataTable(AIConnection conn, string sql return await postgresDb.GetDataTable(conn, sqlQuery); case "ORACLE": return await oracleDb.GetDataTable(conn, sqlQuery); + case "SNOWFLAKE": + return await snowflakeDb.GetDataTable(conn, sqlQuery); } return null; @@ -42,6 +45,8 @@ public async Task GenerateSchema(AIConnection conn) return await postgresDb.GenerateSchema(conn); case "ORACLE": return await oracleDb.GenerateSchema(conn); + case "SNOWFLAKE": + return await snowflakeDb.GenerateSchema(conn); } return new() { SchemaStructured = new List(), SchemaRaw = new List() }; diff --git a/DbChatPro.Core/Services/SnowflakeDatabaseService.cs b/DbChatPro.Core/Services/SnowflakeDatabaseService.cs new file mode 100644 index 0000000..a73f88f --- /dev/null +++ b/DbChatPro.Core/Services/SnowflakeDatabaseService.cs @@ -0,0 +1,119 @@ +// Generated by AI - Snowflake database service implementation +using DBChatPro.Models; +using Snowflake.Data.Client; + +namespace DBChatPro +{ + public class SnowflakeDatabaseService : IDatabaseService + { + public async Task>> GetDataTable(AIConnection conn, string sqlQuery) + { + var rows = new List>(); + using var connection = new SnowflakeDbConnection(conn.ConnectionString); + + await connection.OpenAsync(); + + using var command = new SnowflakeDbCommand(sqlQuery, connection); + using var reader = await command.ExecuteReaderAsync(); + + int count = 0; + bool headersAdded = false; + while (await reader.ReadAsync()) + { + var cols = new List(); + var headerCols = new List(); + if (!headersAdded) + { + for (int i = 0; i < reader.FieldCount; i++) + { + headerCols.Add(reader.GetName(i).ToString()); + } + headersAdded = true; + rows.Add(headerCols); + } + + for (int i = 0; i <= reader.FieldCount - 1; i++) + { + try + { + cols.Add(reader.GetValue(i).ToString()); + } + catch + { + cols.Add("DataTypeConversionError"); + } + } + rows.Add(cols); + } + + return rows; + } + + public async Task GenerateSchema(AIConnection conn) + { + var dbSchema = new DatabaseSchema() { SchemaRaw = new List(), SchemaStructured = new List() }; + List> rows = new(); + + // Parse connection string for database and schema + var pairs = conn.ConnectionString.Split(";"); + var database = pairs.Where(x => x.ToUpper().Contains("DB=") || x.ToUpper().Contains("DATABASE=")).FirstOrDefault()?.Split("=").Last() ?? + pairs.Where(x => x.ToUpper().Contains("DB") && x.Contains("=")).FirstOrDefault()?.Split("=").Last() ?? "PUBLIC"; + var schema = pairs.Where(x => x.ToUpper().Contains("SCHEMA=")).FirstOrDefault()?.Split("=").Last() ?? "PUBLIC"; + + // Snowflake schema query using INFORMATION_SCHEMA + string sqlQuery = @"SELECT + table_name, + column_name + FROM + information_schema.columns + WHERE + table_catalog = ? + AND table_schema = ? + ORDER BY + table_name, + column_name"; + + using var connection = new SnowflakeDbConnection(conn.ConnectionString); + await connection.OpenAsync(); + + using var command = new SnowflakeDbCommand(sqlQuery, connection); + command.Parameters.Add(new SnowflakeDbParameter("1", database)); + command.Parameters.Add(new SnowflakeDbParameter("2", schema)); + + using var reader = await command.ExecuteReaderAsync(); + while (await reader.ReadAsync()) + { + rows.Add(new KeyValuePair(reader.GetString(0), reader.GetString(1))); + } + + var groups = rows.GroupBy(x => x.Key); + + foreach (var group in groups) + { + dbSchema.SchemaStructured.Add(new TableSchema() { TableName = group.Key, Columns = group.Select(x => x.Value).ToList() }); + } + + var textLines = new List(); + + foreach (var table in dbSchema.SchemaStructured) + { + var schemaLine = $"- {table.TableName} ("; + + foreach (var column in table.Columns) + { + schemaLine += column + ", "; + } + + schemaLine += ")"; + schemaLine = schemaLine.Replace(", )", " )"); + + Console.WriteLine(schemaLine); + textLines.Add(schemaLine); + } + + dbSchema.SchemaRaw = textLines; + + return dbSchema; + } + } +} \ No newline at end of file