diff --git a/TODO.md b/TODO.md index 402e551..a3bf01e 100644 --- a/TODO.md +++ b/TODO.md @@ -2,15 +2,11 @@ ## General Plans and Ideas -- add an option to the SQLiter constructor to delete the database file if it - already exists. Default to False. - add attributes to the BaseDBModel to read the table-name, file-name, is-memory etc. - add an 'execute' method to the main class to allow executing arbitrary SQL queries which can be chained to the 'find_first' etc methods or just used directly. -- add a method to drop the entire database easiest way is prob to just delete - and recreate the database file. - add an 'exists_ok' (default True) parameter to the 'create_table' method so it will raise an exception if the table already exists and this is set to False. - add a `rollback` method to the main class to allow manual rollbacks. diff --git a/docs/usage.md b/docs/usage.md index 99fc4bf..8408aea 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -91,6 +91,18 @@ exception will be raised. > db = SqliterDB(":memory:") > ``` +#### Resetting the Database + +If you want to reset the database when you create the SqliterDB object, you can +pass `reset=True`: + +```python +db = SqliterDB("your_database.db", reset=True) +``` + +This will effectively drop all user tables from the database. The file itself is +not deleted, only the tables are dropped. + ### Creating Tables ```python diff --git a/pyproject.toml b/pyproject.toml index 3bdac43..f9c7278 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,9 @@ target-version = "py39" # minimum python version supported indent-style = "space" quote-style = "double" +[tool.ruff.lint.pylint] +max-args = 6 + [tool.ruff.lint.pep8-naming] classmethod-decorators = ["pydantic.validator", "pydantic.root_validator"] diff --git a/sqliter/sqliter.py b/sqliter/sqliter.py index 6e7c9b6..03edd77 100644 --- a/sqliter/sqliter.py +++ b/sqliter/sqliter.py @@ -54,6 +54,7 @@ def __init__( auto_commit: bool = True, debug: bool = False, logger: Optional[logging.Logger] = None, + reset: bool = False, ) -> None: """Initialize a new SqliterDB instance. @@ -63,6 +64,8 @@ def __init__( auto_commit: Whether to automatically commit transactions. debug: Whether to enable debug logging. logger: Custom logger for debug output. + reset: Whether to reset the database on initialization. This will + basically drop all existing tables. Raises: ValueError: If no filename is provided for a non-memory database. @@ -81,10 +84,37 @@ def __init__( self.debug = debug self.logger = logger self.conn: Optional[sqlite3.Connection] = None + self.reset = reset if self.debug: self._setup_logger() + if self.reset: + self._reset_database() + + def _reset_database(self) -> None: + """Drop all user-created tables in the database.""" + with self.connect() as conn: + cursor = conn.cursor() + + # Get all table names, excluding SQLite system tables + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' " + "AND name NOT LIKE 'sqlite_%';" + ) + tables = cursor.fetchall() + + # Drop each user-created table + for table in tables: + cursor.execute(f"DROP TABLE IF EXISTS {table[0]}") + + conn.commit() + + if self.debug and self.logger: + self.logger.debug( + "Database reset: %s user-created tables dropped.", len(tables) + ) + def _setup_logger(self) -> None: """Set up the logger for debug output. diff --git a/tests/conftest.py b/tests/conftest.py index 8b2a88c..4145c45 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -194,3 +194,9 @@ def db_mock_complex_debug() -> SqliterDB: ) ) return db + + +@pytest.fixture +def temp_db_path(tmp_path) -> str: + """Fixture to create a temporary database file path.""" + return str(tmp_path / "test_db.sqlite") diff --git a/tests/test_debug_logging.py b/tests/test_debug_logging.py index 2ff4e97..1176cfe 100644 --- a/tests/test_debug_logging.py +++ b/tests/test_debug_logging.py @@ -256,3 +256,10 @@ def test_debug_output_drop_table( assert ( "Executing SQL: DROP TABLE IF EXISTS complex_model" in caplog.text ) + + def test_reset_database_debug_logging(self, temp_db_path, caplog) -> None: + """Test that resetting the database logs debug information.""" + with caplog.at_level(logging.DEBUG): + SqliterDB(temp_db_path, reset=True, debug=True) + + assert "Database reset: 0 user-created tables dropped." in caplog.text diff --git a/tests/test_sqliter.py b/tests/test_sqliter.py index 1d5e2ef..bb01e56 100644 --- a/tests/test_sqliter.py +++ b/tests/test_sqliter.py @@ -720,3 +720,76 @@ def test_complex_model_primary_key(self, db_mock) -> None: f"Expected 'INTEGER' type for primary key, but got " f"{primary_key_column[2]}" ) + + def test_reset_database_on_init(self, temp_db_path) -> None: + """Test that the database is reset when reset=True is passed.""" + + class TestModel(BaseDBModel): + name: str + + class Meta: + table_name = "test_reset_table" + + # Create a database and add some data + db = SqliterDB(temp_db_path) + db.create_table(TestModel) + db.insert(TestModel(name="Test Data")) + db.close() + + # Create a new connection with reset=True + db_reset = SqliterDB(temp_db_path, reset=True) + + # Verify the table no longer exists + with pytest.raises(RecordFetchError): + db_reset.select(TestModel).fetch_all() + + def test_reset_database_preserves_connection(self, temp_db_path) -> None: + """Test that resetting the database doesn't break the connection.""" + + class TestModel(BaseDBModel): + name: str + + class Meta: + table_name = "test_reset_table" + + db = SqliterDB(temp_db_path, reset=True) + + # Create a table after reset + db.create_table(TestModel) + db.insert(TestModel(name="New Data")) + + # Verify data exists + result = db.select(TestModel).fetch_all() + assert len(result) == 1 + + def test_reset_database_with_multiple_tables(self, temp_db_path) -> None: + """Test that reset drops all tables in the database.""" + + class TestModel1(BaseDBModel): + name: str + + class Meta: + table_name = "test_reset_table1" + + class TestModel2(BaseDBModel): + age: int + + class Meta: + table_name = "test_reset_table2" + + # Create a database and add some data + db = SqliterDB(temp_db_path) + db.create_table(TestModel1) + db.create_table(TestModel2) + db.insert(TestModel1(name="Test Data")) + db.insert(TestModel2(age=25)) + db.close() + + # Reset the database + db_reset = SqliterDB(temp_db_path, reset=True) + + # Verify both tables no longer exist + with pytest.raises(RecordFetchError): + db_reset.select(TestModel1).fetch_all() + with pytest.raises(RecordFetchError): + db_reset.select(TestModel2).fetch_all()