Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cli/nao_core/commands/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def check_database_connection(db_config: AnyDatabaseConfig) -> tuple[bool, str]:
Tuple of (success, message)
"""
try:
if db_config.type == "redshift":
schemas = db_config.list_schemas()
schema_count = len(schemas)
return True, f"Connected successfully ({schema_count} schemas found)"

conn = db_config.connect()
# Run a simple query to verify the connection works
if hasattr(db_config, "dataset_id") and db_config.dataset_id:
Expand Down
7 changes: 2 additions & 5 deletions cli/nao_core/commands/sync/providers/databases/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ def sync_redshift(
db_path = base_path / "type=redshift" / f"database={db_name}"
state = DatabaseSyncState(db_path=db_path)

if db_config.schema_name:
schemas = [db_config.schema_name]
else:
schemas = conn.list_databases()
schemas = db_config.list_schemas()

schema_task = progress.add_task(
f"[dim]{db_config.name}[/dim]",
Expand All @@ -40,7 +37,7 @@ def sync_redshift(

for schema in schemas:
try:
all_tables = conn.list_tables(database=schema)
all_tables = db_config.list_tables(schema)
except Exception:
progress.update(schema_task, advance=1)
continue
Expand Down
56 changes: 44 additions & 12 deletions cli/nao_core/config/databases/redshift.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pathlib import Path
from typing import Literal

import ibis
from ibis import BaseBackend
import redshift_connector
from pydantic import BaseModel, Field
from sshtunnel import SSHTunnelForwarder

Expand Down Expand Up @@ -31,7 +30,6 @@ class RedshiftConfig(DatabaseConfig):
database: str = Field(description="Database name")
user: str = Field(description="Username")
password: str = Field(description="Password")
schema_name: str | None = Field(default=None, description="Default schema (optional, uses 'public' if not set)")
sslmode: str = Field(default="require", description="SSL mode for the connection")
ssh_tunnel: RedshiftSSHTunnelConfig | None = Field(default=None, description="SSH tunnel configuration (optional)")

Expand All @@ -49,7 +47,6 @@ def promptConfig(cls) -> "RedshiftConfig":
user = ask_text("Username:", required_field=True)
password = ask_text("Password:", password=True, required_field=True)
sslmode = ask_text("SSL mode:", default="require") or "require"
schema_name = ask_text("Default schema (uses 'public' if empty):")

use_ssh = ask_confirm("Use SSH tunnel?", default=False)
ssh_tunnel = None
Expand Down Expand Up @@ -80,13 +77,12 @@ def promptConfig(cls) -> "RedshiftConfig":
database=database or "",
user=user or "",
password=password or "",
schema_name=schema_name,
sslmode=sslmode,
ssh_tunnel=ssh_tunnel,
)

def connect(self) -> BaseBackend:
"""Create an Ibis Redshift connection."""
def connect(self):
"""Create a Redshift connection via redshift_connector."""

# Determine connection host and port
connect_host = self.host
Expand Down Expand Up @@ -116,17 +112,53 @@ def connect(self) -> BaseBackend:
"database": self.database,
"user": self.user,
"password": self.password,
"client_encoding": "utf8",
"sslmode": self.sslmode,
}

if self.schema_name:
kwargs["schema"] = self.schema_name

return ibis.postgres.connect(
return redshift_connector.connect(
**kwargs,
)

def list_schemas(self) -> list[str]:
"""List schemas in the Redshift database."""
conn = self.connect()
cursor = conn.cursor()

cursor.execute("""
SELECT nspname
FROM pg_namespace
WHERE nspname NOT LIKE 'pg_%'
AND nspname != 'information_schema'
ORDER BY nspname;
""")

schemas = [row[0] for row in cursor.fetchall()]
cursor.close()
conn.close()

return schemas

def list_tables(self, schema: str) -> list[str]:
"""List tables in a specific schema."""
conn = self.connect()
cursor = conn.cursor()

cursor.execute(
"""
SELECT tablename
FROM pg_tables
WHERE schemaname = %s
ORDER BY tablename;
""",
(schema,),
)

tables = [row[0] for row in cursor.fetchall()]
cursor.close()
conn.close()

return tables

def get_database_name(self) -> str:
"""Get the database name for Redshift."""

Expand Down
1 change: 1 addition & 0 deletions cli/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"mistralai>=1.11.1",
"google-genai>=1.61.0",
"sshtunnel>=0.4.0",
"redshift-connector>=2.1.11",
]

[project.optional-dependencies]
Expand Down
Loading
Loading