Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fakesnow/conn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from collections import OrderedDict
from collections.abc import Iterable
from pathlib import Path
from types import TracebackType
Expand All @@ -23,7 +24,7 @@ class FakeSnowflakeConnection:
def __init__(
self,
duck_conn: DuckDBPyConnection,
results_cache: dict[str, tuple],
results_cache: OrderedDict[str, tuple],
database: str | None = None,
schema: str | None = None,
create_database: bool = True,
Expand Down
5 changes: 3 additions & 2 deletions fakesnow/copy_into.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _get_table_columns(


def _params(expr: exp.Copy, params: MutableParams | None = None) -> CopyParams:
kwargs = {}
kwargs: dict[str, Any] = {}
force = False
purge = False
on_error = "ABORT_STATEMENT"
Expand Down Expand Up @@ -381,7 +381,7 @@ def _inserts(

if parquet_info and parquet_info.is_single_variant:
# Single VARIANT column: convert entire parquet row to JSON
inserts = []
inserts: list[exp.Expression] = []
for url in urls:
parquet_col_names = parquet_info.parquet_columns[url]

Expand Down Expand Up @@ -543,6 +543,7 @@ def read_expression(self, url: str) -> exp.Expression: ...

@staticmethod
def make_eq(name: str, value: list | str | int | bool) -> exp.EQ:
expression: exp.Expression
if isinstance(value, list):
expression = exp.array(*[exp.Literal(this=str(v), is_string=isinstance(v, str)) for v in value])
elif isinstance(value, bool):
Expand Down
38 changes: 25 additions & 13 deletions fakesnow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from fakesnow import logger
from fakesnow.copy_into import copy_into
from fakesnow.params import MutableParams
from fakesnow.rowtype import describe_as_result_metadata
from fakesnow.rowtype import describe_as_result_metadata, describe_as_rowtype
from fakesnow.transforms import stage

if TYPE_CHECKING:
Expand Down Expand Up @@ -81,15 +81,15 @@ def __init__(
self._conn = conn
self._duck_conn = duck_conn
self._use_dict_result = use_dict_result
self._last_sql = None
self._last_params = None
self._last_transformed = None
self._sqlstate = None
self._last_sql: str | None = None
self._last_params: MutableParams | None = None
self._last_transformed: exp.Expression | None = None
self._sqlstate: str | None = None
self._arraysize = 1
self._arrow_table = None
self._arrow_table_fetch_index = None
self._rowcount = None
self._sfqid = None
self._arrow_table: pyarrow.Table | None = None
self._arrow_table_fetch_index: int | None = None
self._rowcount: int | None = None
self._sfqid: str | None = None
self._converter = snowflake.connector.converter.SnowflakeConverter()
self._prefetch_hook: Callable[[], None] | None = None

Expand Down Expand Up @@ -213,8 +213,8 @@ def prefetch() -> None:
"Cannot retrieve data on the status of this query. "
"No information returned from server for query '{}'"
)
# Restore the cached result data
self._arrow_table, self._rowcount, self._last_sql, self._last_params, self._last_transformed = value
# Restore the cached result data (6-tuple: arrow_table, rowcount, last_sql, last_params, last_transformed, rowtype)
self._arrow_table, self._rowcount, self._last_sql, self._last_params, self._last_transformed, _ = value
self._sfqid = sfqid
self._arrow_table_fetch_index = None
self._prefetch_hook = None
Expand Down Expand Up @@ -357,8 +357,8 @@ def _execute(self, transformed: exp.Expression, params: MutableParams | None = N
raise snowflake.connector.errors.ProgrammingError(
msg=f"Statement {sfqid} not found", errno=709, sqlstate="02000"
)
# Restore the cached result data
self._arrow_table, self._rowcount, self._last_sql, self._last_params, self._last_transformed = value
# Restore the cached result data (6-tuple: arrow_table, rowcount, last_sql, last_params, last_transformed, rowtype)
self._arrow_table, self._rowcount, self._last_sql, self._last_params, self._last_transformed, _ = value
self._sfqid = sfqid
self._arrow_table_fetch_index = None
return
Expand Down Expand Up @@ -554,12 +554,24 @@ def _execute(self, transformed: exp.Expression, params: MutableParams | None = N
self._last_params = None if result_sql else params
self._last_transformed = transformed

# Compute rowtype for cache (needed by GET endpoint to generate rowsetBase64)
# Skip for DESCRIBE queries to avoid infinite recursion
try:
if cmd not in {"DESCRIBE TABLE", "DESCRIBE VIEW", "DESCRIBE"}:
rowtype = describe_as_rowtype(self._describe_last_sql())
else:
rowtype = []
except Exception:
# If rowtype computation fails, use empty list as fallback
rowtype = []

self._conn.results_cache[self._sfqid] = (
self._arrow_table,
self._rowcount,
self._last_sql,
self._last_params,
self._last_transformed,
rowtype, # 6th element - required for result_scan() and get_results_from_sfqid()
)

def executemany(
Expand Down
37 changes: 36 additions & 1 deletion fakesnow/instance.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from __future__ import annotations

import logging
import os
from collections import OrderedDict
from pathlib import Path
from typing import Any

import duckdb

import fakesnow.fakes as fakes
import fakesnow.macros as macros
from fakesnow import info_schema
from fakesnow.transforms import show

logger = logging.getLogger("fakesnow.instance")

GLOBAL_DATABASE_NAME = "_fs_global"


Expand All @@ -25,7 +31,7 @@ def __init__(
self.db_path = db_path
self.nop_regexes = nop_regexes

self.results_cache: dict[str, tuple] = {}
self.results_cache: OrderedDict[str, tuple] = OrderedDict()
self.duck_conn = duckdb.connect(database=":memory:")

# create a "global" database for storing objects which span databases.
Expand All @@ -37,6 +43,35 @@ def __init__(
# use UTC instead of local time zone for consistent testing
self.duck_conn.execute("SET GLOBAL TimeZone = 'UTC'")

# Attach existing database files from db_path for persistence across restarts
if self.db_path:
self._attach_existing_databases()

def _attach_existing_databases(self) -> None:
    """Scan db_path for existing .db files and attach them.

    Run at startup so databases persisted under db_path in a previous
    process are visible again without an explicit CREATE DATABASE.
    """
    db_path = Path(self.db_path)  # type: ignore[arg-type]
    if not db_path.is_dir():
        logger.warning("db_path does not exist or is not a directory: %s", db_path)
        return

    for db_file in db_path.glob("*.db"):
        # Database name is the filename without the .db extension
        # (uppercased to follow Snowflake's identifier convention).
        db_name = db_file.stem.upper()

        # Skip files whose database is already attached. Use a bound
        # parameter rather than f-string interpolation so quotes or other
        # odd characters in a filename cannot break the query.
        existing = self.duck_conn.execute(
            "SELECT COUNT(*) FROM information_schema.schemata WHERE upper(catalog_name) = ?",
            [db_name],
        ).fetchone()
        logger.debug("Checking existing database %s: [%s] attached", db_name, existing)
        if existing and existing[0] > 0:
            logger.info("Database %s already attached, skipping", db_name)
            continue

        logger.info("Attaching existing database: %s from %s", db_name, db_file)
        self.duck_conn.execute(f"ATTACH DATABASE '{db_file}' AS {db_name}")
        # Recreate the per-database info-schema views and macros that
        # fakesnow layers on top of every attached database.
        self.duck_conn.execute(info_schema.per_db_creation_sql(db_name))
        self.duck_conn.execute(macros.creation_sql(db_name))

def connect(
self,
database: str | None = None,
Expand Down
Loading