Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DBChatPro.UI/Components/Pages/ConnectDb.razor
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
<MudSelectItem Value="@("MYSQL")" T="string">MYSQL</MudSelectItem>
<MudSelectItem Value="@("POSTGRESQL")" T="string">POSTGRESQL</MudSelectItem>
<MudSelectItem Value="@("ORACLE")" T="string">ORACLE</MudSelectItem>
<MudSelectItem Value="@("SNOWFLAKE")" T="string">SNOWFLAKE</MudSelectItem>
</MudSelect>
<MudTextField @bind-Value="aiConnection.Name" T="string" Label="Connection name" Variant="Variant.Text" />
<MudTextField @bind-Value="aiConnection.ConnectionString" T="string" Label="Connection string" Variant="Variant.Text" Lines="5" />
Expand Down
1 change: 1 addition & 0 deletions DBChatPro.UI/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
builder.Services.AddScoped<SqlServerDatabaseService>();
builder.Services.AddScoped<PostgresDatabaseService>();
builder.Services.AddScoped<OracleDatabaseService>();
builder.Services.AddScoped<SnowflakeDatabaseService>();

if (!string.IsNullOrEmpty(builder.Configuration["AWS:Profile"]))
{
Expand Down
3 changes: 3 additions & 0 deletions DBChatPro.UI/SampleConnectionStrings.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ Host=127.0.0.1;Port=5432;Database=sakila;Username=sakila;Password=p_ssW0rd;

<!-- Sample Oracle connection string -->
User Id=<your-username>;Password=<your-password>;Data Source=<host>:<port>/<service-name>;

<!-- Sample Snowflake connection string -->
account=<your-account>;user=<your-username>;password=<your-password>;db=<your-database>;schema=<your-schema>;warehouse=<your-warehouse>;
1 change: 1 addition & 0 deletions DbChatPro.Core/DbChatPro.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
<PackageReference Include="AWSSDK.Core" Version="3.7.402.66" />
<PackageReference Include="AWSSDK.SecurityToken" Version="3.7.401.109" />
<PackageReference Include="AWSSDK.Extensions.NETCore.Setup" Version="3.7.400" />
<PackageReference Include="Snowflake.Data" Version="4.1.0" />
</ItemGroup>
</Project>
7 changes: 6 additions & 1 deletion DbChatPro.Core/Services/DatabaseManagerService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public class DatabaseManagerService(
MySqlDatabaseService mySqlDb,
SqlServerDatabaseService msSqlDb,
PostgresDatabaseService postgresDb,
OracleDatabaseService oracleDb) : IDatabaseService
OracleDatabaseService oracleDb,
SnowflakeDatabaseService snowflakeDb) : IDatabaseService
{
public async Task<List<List<string>>> GetDataTable(AIConnection conn, string sqlQuery)
{
Expand All @@ -25,6 +26,8 @@ public async Task<List<List<string>>> GetDataTable(AIConnection conn, string sql
return await postgresDb.GetDataTable(conn, sqlQuery);
case "ORACLE":
return await oracleDb.GetDataTable(conn, sqlQuery);
case "SNOWFLAKE":
return await snowflakeDb.GetDataTable(conn, sqlQuery);
}

return null;
Expand All @@ -42,6 +45,8 @@ public async Task<DatabaseSchema> GenerateSchema(AIConnection conn)
return await postgresDb.GenerateSchema(conn);
case "ORACLE":
return await oracleDb.GenerateSchema(conn);
case "SNOWFLAKE":
return await snowflakeDb.GenerateSchema(conn);
}

return new() { SchemaStructured = new List<TableSchema>(), SchemaRaw = new List<string>() };
Expand Down
119 changes: 119 additions & 0 deletions DbChatPro.Core/Services/SnowflakeDatabaseService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Generated by AI - Snowflake database service implementation
using DBChatPro.Models;
using Snowflake.Data.Client;

namespace DBChatPro
{
public class SnowflakeDatabaseService : IDatabaseService
{
public async Task<List<List<string>>> GetDataTable(AIConnection conn, string sqlQuery)
{
var rows = new List<List<string>>();
using var connection = new SnowflakeDbConnection(conn.ConnectionString);

await connection.OpenAsync();

using var command = new SnowflakeDbCommand(sqlQuery, connection);
using var reader = await command.ExecuteReaderAsync();

int count = 0;
bool headersAdded = false;
while (await reader.ReadAsync())
{
var cols = new List<string>();
var headerCols = new List<string>();
if (!headersAdded)
{
for (int i = 0; i < reader.FieldCount; i++)
{
headerCols.Add(reader.GetName(i).ToString());
}
headersAdded = true;
rows.Add(headerCols);
}

for (int i = 0; i <= reader.FieldCount - 1; i++)
{
try
{
cols.Add(reader.GetValue(i).ToString());
}
catch
{
cols.Add("DataTypeConversionError");
}
}
rows.Add(cols);
}

return rows;
}

public async Task<DatabaseSchema> GenerateSchema(AIConnection conn)
{
var dbSchema = new DatabaseSchema() { SchemaRaw = new List<string>(), SchemaStructured = new List<TableSchema>() };
List<KeyValuePair<string, string>> rows = new();

// Parse connection string for database and schema
var pairs = conn.ConnectionString.Split(";");
var database = pairs.Where(x => x.ToUpper().Contains("DB=") || x.ToUpper().Contains("DATABASE=")).FirstOrDefault()?.Split("=").Last() ??
pairs.Where(x => x.ToUpper().Contains("DB") && x.Contains("=")).FirstOrDefault()?.Split("=").Last() ?? "PUBLIC";
var schema = pairs.Where(x => x.ToUpper().Contains("SCHEMA=")).FirstOrDefault()?.Split("=").Last() ?? "PUBLIC";

// Snowflake schema query using INFORMATION_SCHEMA
string sqlQuery = @"SELECT
table_name,
column_name
FROM
information_schema.columns
WHERE
table_catalog = ?
AND table_schema = ?
ORDER BY
table_name,
column_name";

using var connection = new SnowflakeDbConnection(conn.ConnectionString);
await connection.OpenAsync();

using var command = new SnowflakeDbCommand(sqlQuery, connection);
command.Parameters.Add(new SnowflakeDbParameter("1", database));
command.Parameters.Add(new SnowflakeDbParameter("2", schema));

using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
rows.Add(new KeyValuePair<string, string>(reader.GetString(0), reader.GetString(1)));
}

var groups = rows.GroupBy(x => x.Key);

foreach (var group in groups)
{
dbSchema.SchemaStructured.Add(new TableSchema() { TableName = group.Key, Columns = group.Select(x => x.Value).ToList() });
}

var textLines = new List<string>();

foreach (var table in dbSchema.SchemaStructured)
{
var schemaLine = $"- {table.TableName} (";

foreach (var column in table.Columns)
{
schemaLine += column + ", ";
}

schemaLine += ")";
schemaLine = schemaLine.Replace(", )", " )");

Console.WriteLine(schemaLine);
textLines.Add(schemaLine);
}

dbSchema.SchemaRaw = textLines;

return dbSchema;
}
}
}