Skip to content

Commit

Permalink
Add reset= to SqliterDB(), to drop all existing tables (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
seapagan authored Sep 27, 2024
1 parent b259161 commit cdd369e
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 4 deletions.
4 changes: 0 additions & 4 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@

## General Plans and Ideas

- add an option to the SQLiter constructor to delete the database file if it
already exists. Default to False.
- add attributes to the BaseDBModel to read the table-name, file-name, is-memory
etc.
- add an 'execute' method to the main class to allow executing arbitrary SQL
queries which can be chained to the 'find_first' etc methods or just used
directly.
- add a method to drop the entire database easiest way is prob to just delete
and recreate the database file.
- add an 'exists_ok' (default True) parameter to the 'create_table' method so it
will raise an exception if the table already exists and this is set to False.
- add a `rollback` method to the main class to allow manual rollbacks.
Expand Down
12 changes: 12 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ exception will be raised.
> db = SqliterDB(":memory:")
> ```
#### Resetting the Database
If you want to reset the database when you create the SqliterDB object, you can
pass `reset=True`:
```python
db = SqliterDB("your_database.db", reset=True)
```
This will effectively drop all user tables from the database. The file itself is
not deleted, only the tables are dropped.

### Creating Tables

```python
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ target-version = "py39" # minimum python version supported
indent-style = "space"
quote-style = "double"

[tool.ruff.lint.pylint]
max-args = 6

[tool.ruff.lint.pep8-naming]
classmethod-decorators = ["pydantic.validator", "pydantic.root_validator"]

Expand Down
30 changes: 30 additions & 0 deletions sqliter/sqliter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
auto_commit: bool = True,
debug: bool = False,
logger: Optional[logging.Logger] = None,
reset: bool = False,
) -> None:
"""Initialize a new SqliterDB instance.
Expand All @@ -63,6 +64,8 @@ def __init__(
auto_commit: Whether to automatically commit transactions.
debug: Whether to enable debug logging.
logger: Custom logger for debug output.
reset: Whether to reset the database on initialization. This will
basically drop all existing tables.
Raises:
ValueError: If no filename is provided for a non-memory database.
Expand All @@ -81,10 +84,37 @@ def __init__(
self.debug = debug
self.logger = logger
self.conn: Optional[sqlite3.Connection] = None
self.reset = reset

if self.debug:
self._setup_logger()

if self.reset:
self._reset_database()

def _reset_database(self) -> None:
"""Drop all user-created tables in the database."""
with self.connect() as conn:
cursor = conn.cursor()

# Get all table names, excluding SQLite system tables
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' "
"AND name NOT LIKE 'sqlite_%';"
)
tables = cursor.fetchall()

# Drop each user-created table
for table in tables:
cursor.execute(f"DROP TABLE IF EXISTS {table[0]}")

conn.commit()

if self.debug and self.logger:
self.logger.debug(
"Database reset: %s user-created tables dropped.", len(tables)
)

def _setup_logger(self) -> None:
"""Set up the logger for debug output.
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,9 @@ def db_mock_complex_debug() -> SqliterDB:
)
)
return db


@pytest.fixture
def temp_db_path(tmp_path) -> str:
"""Fixture to create a temporary database file path."""
return str(tmp_path / "test_db.sqlite")
7 changes: 7 additions & 0 deletions tests/test_debug_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,10 @@ def test_debug_output_drop_table(
assert (
"Executing SQL: DROP TABLE IF EXISTS complex_model" in caplog.text
)

def test_reset_database_debug_logging(self, temp_db_path, caplog) -> None:
"""Test that resetting the database logs debug information."""
with caplog.at_level(logging.DEBUG):
SqliterDB(temp_db_path, reset=True, debug=True)

assert "Database reset: 0 user-created tables dropped." in caplog.text
73 changes: 73 additions & 0 deletions tests/test_sqliter.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,76 @@ def test_complex_model_primary_key(self, db_mock) -> None:
f"Expected 'INTEGER' type for primary key, but got "
f"{primary_key_column[2]}"
)

def test_reset_database_on_init(self, temp_db_path) -> None:
"""Test that the database is reset when reset=True is passed."""

class TestModel(BaseDBModel):
name: str

class Meta:
table_name = "test_reset_table"

# Create a database and add some data
db = SqliterDB(temp_db_path)
db.create_table(TestModel)
db.insert(TestModel(name="Test Data"))
db.close()

# Create a new connection with reset=True
db_reset = SqliterDB(temp_db_path, reset=True)

# Verify the table no longer exists
with pytest.raises(RecordFetchError):
db_reset.select(TestModel).fetch_all()

def test_reset_database_preserves_connection(self, temp_db_path) -> None:
"""Test that resetting the database doesn't break the connection."""

class TestModel(BaseDBModel):
name: str

class Meta:
table_name = "test_reset_table"

db = SqliterDB(temp_db_path, reset=True)

# Create a table after reset
db.create_table(TestModel)
db.insert(TestModel(name="New Data"))

# Verify data exists
result = db.select(TestModel).fetch_all()
assert len(result) == 1

def test_reset_database_with_multiple_tables(self, temp_db_path) -> None:
"""Test that reset drops all tables in the database."""

class TestModel1(BaseDBModel):
name: str

class Meta:
table_name = "test_reset_table1"

class TestModel2(BaseDBModel):
age: int

class Meta:
table_name = "test_reset_table2"

# Create a database and add some data
db = SqliterDB(temp_db_path)
db.create_table(TestModel1)
db.create_table(TestModel2)
db.insert(TestModel1(name="Test Data"))
db.insert(TestModel2(age=25))
db.close()

# Reset the database
db_reset = SqliterDB(temp_db_path, reset=True)

# Verify both tables no longer exist
with pytest.raises(RecordFetchError):
db_reset.select(TestModel1).fetch_all()
with pytest.raises(RecordFetchError):
db_reset.select(TestModel2).fetch_all()

0 comments on commit cdd369e

Please sign in to comment.