diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2ee425..fda8b51 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,6 +30,8 @@ repos: hooks: - id: mypy name: "run mypy" + additional_dependencies: + - pydantic - repo: https://github.com/astral-sh/uv-pre-commit # uv version. diff --git a/README.md b/README.md index a91c431..b09e4c4 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,10 @@ Website](https://sqliter.grantramsay.dev) > [!CAUTION] > This project is still in the early stages of development and is lacking some -> planned functionality. Please use with caution. +> planned functionality. Please use with caution - Classes and methods may +> change until a stable release is made. I'll try to keep this to an absolute +> minimum and the releases and documentation will be very clear about any +> breaking changes. > > Also, structures like `list`, `dict`, `set` etc are not supported **at this > time** as field types, since SQLite does not have a native column type for @@ -45,7 +48,7 @@ Website](https://sqliter.grantramsay.dev) - Table creation based on Pydantic models - CRUD operations (Create, Read, Update, Delete) -- Basic query building with filtering, ordering, and pagination +- Chained Query building with filtering, ordering, and pagination - Transaction support - Custom exceptions for better error handling - Full type hinting and type checking @@ -114,7 +117,7 @@ db.create_table(User) # Insert a record user = User(name="John Doe", age=30) -new_record = db.insert(user) +new_user = db.insert(user) # Query records results = db.select(User).filter(name="John Doe").fetch_all() @@ -122,11 +125,11 @@ for user in results: print(f"User: {user.name}, Age: {user.age}") # Update a record -user.age = 31 -db.update(User, new_record) +new_user.age = 31 +db.update(new_user) # Delete a record -db.delete(User, new_record.pk) +db.delete(User, new_user.pk) ``` See the [Usage](https://sqliter.grantramsay.dev/usage) section of the documentation diff --git a/demo.py b/demo.py index d074658..b3c8a6c 100644 --- a/demo.py +++ b/demo.py @@ -41,7 +41,9 @@ def main() -> None: level=logging.DEBUG, format="%(levelname)-8s%(message)s" ) - db = SqliterDB(memory=True, auto_commit=True, debug=True) + db = SqliterDB( + "demo.db", memory=False, auto_commit=True, debug=True, reset=True + ) with db: db.create_table(UserModel) # Create the users table user1 = UserModel( @@ -62,7 +64,7 @@ def main() -> None: ) try: db.insert(user1) - db.insert(user2) + user2_instance = db.insert(user2) db.insert(user3) except RecordInsertionError as exc: logging.error(exc) # noqa: TRY400 @@ -79,8 +81,11 @@ def main() -> None: ) logging.info(all_reversed) - fetched_user = db.get(UserModel, "jdoe2") - logging.info(fetched_user) + if user2_instance is None: + logging.error("User2 ID not found.") + else: + fetched_user = db.get(UserModel, user2_instance.pk) + logging.info("Fetched (%s)", fetched_user) count = db.select(UserModel).count() logging.info("Total Users: %s", count) diff --git a/docs/guide/data-ops.md b/docs/guide/data-ops.md index a96f9fb..f513447 100644 --- a/docs/guide/data-ops.md +++ b/docs/guide/data-ops.md @@ -8,7 +8,16 @@ into the correct table: ```python user = User(name="Jane Doe", age=25, email="jane@example.com") -db.insert(user) +result = db.insert(user) +``` + +The `result` variable will contain a new instance of the model, with the primary +key value set to the newly-created primary key in the database. You should use +this instance to access the primary key value and other fields: + +```python +print(f"New record inserted with primary key: {result.pk}") +print(f"Name: {result.name}, Age: {result.age}, Email: {result.email}") ``` > [!IMPORTANT] @@ -46,6 +55,10 @@ See [Filtering Results](filtering.md) for more advanced filtering options. ## Updating Records +You can update records in the database by modifying the fields of the model +instance and then calling the `update()` method. You just pass the model +instance to the method: + ```python user.age = 26 db.update(user) @@ -53,8 +66,11 @@ db.update(user) ## Deleting Records +To delete a record from the database, you need to pass the model class and the +primary key value of the record you want to delete: + ```python -db.delete(User, "Jane Doe") +db.delete(User, user.pk) ``` ## Commit your changes diff --git a/docs/guide/filtering.md b/docs/guide/filtering.md index e42f50a..ec279d2 100644 --- a/docs/guide/filtering.md +++ b/docs/guide/filtering.md @@ -8,6 +8,17 @@ records, and can be combined with other methods like `order()`, `limit()`, and result = db.select(User).filter(age__lte=30).limit(10).fetch_all() ``` +It is possible to both add multiple filters in the same call, and to chain +multiple filter calls together: + +```python +result = db.select(User).filter(age__gte=20, age__lte=30).fetch_all() +``` + +```python +result = db.select(User).filter(age__gte=20).filter(age__lte=30).fetch_all() +``` + ## Basic Filters - `__eq`: Equal to (default if no operator is specified) diff --git a/docs/guide/guide.md b/docs/guide/guide.md index 47b4e87..be1dde5 100644 --- a/docs/guide/guide.md +++ b/docs/guide/guide.md @@ -41,9 +41,12 @@ Inserting records is straightforward with SQLiter: ```python user = User(name="John Doe", age=30, email="john@example.com") -db.insert(user) +new_record = db.insert(user) ``` +If successful, `new_record` will contain a model the same as was passed to it, +but including the newly-created primary key value. + ## Basic Queries You can easily query all records from a table: @@ -109,7 +112,8 @@ returned. ## Updating Records -Records can be updated seamlessly: +Records can be updated seamlessly. Simply modify the fields of the model +instance and pass that to the `update()` method: ```python user.age = 31 @@ -118,12 +122,23 @@ db.update(user) ## Deleting Records -Deleting records is simple as well: +Deleting records is simple as well. You just need to pass the Model that defines +your table and the primary key value of the record you want to delete: ```python -db.delete(User, "John Doe") +db.delete(User, 1) ``` +> [!NOTE] +> +> You can get the primary key value from the record or model instance itself, +> e.g., `new_record.pk` and pass that as the second argument to the `delete()` +> method: +> +> ```python +> db.delete(User, new_record.pk) +> ``` + ## Advanced Query Features ### Ordering diff --git a/docs/guide/models.md b/docs/guide/models.md index 4d30c06..e4c4145 100644 --- a/docs/guide/models.md +++ b/docs/guide/models.md @@ -1,8 +1,14 @@ -# Defining Models +# Models -Models in SQLiter use Pydantic to encapsulate the logic. All models should -inherit from SQLiter's `BaseDBModel`. You can define your -models like this: +Each individual table in your database should be represented by a model. This +model should inherit from `BaseDBModel` and define the fields that should be +stored in the table. Under the hood, the model is a Pydantic model, so you can +use all the features of Pydantic models, such as default values, type hints, and +validation. + +## Defining Models + +Models are defined like this: ```python from sqliter.model import BaseDBModel @@ -11,23 +17,36 @@ class User(BaseDBModel): name: str age: int email: str - - class Meta: - table_name = "users" - primary_key = "name" # Default is "id" - create_pk = False # disable auto-creating an incrementing primary key - default is True ``` -For a standard database with an auto-incrementing integer `id` primary key, you -do not need to specify the `primary_key` or `create_pk` fields. If you want to -specify a different primary key field name, you can do so using the -`primary_key` field in the `Meta` class. +You can create as many Models as you need, each representing a different table +in your database. The fields in the model will be used to create the columns in +the table. + +> [!IMPORTANT] +> +> - Type-hints are **REQUIRED** for each field in the model. +> - The Model **automatically** creates an **auto-incrementing integer primary +> key** for each table called `pk`, you do not need to define it yourself. + +### Custom Table Name + +By default, the table name will be the same as the model name, converted to +'snake_case' and pluralized (e.g., `User` -> `users`). Also, any 'Model' suffix +will be removed (e.g., `UserModel` -> `users`). To override this behavior, you +can specify the `table_name` in the `Meta` class manually as below: + +```python +from sqliter.model import BaseDBModel + +class User(BaseDBModel): + name: str + age: int + email: str -If `table_name` is not specified, the table name will be the same as the model -name, converted to 'snake_case' and pluralized (e.g., `User` -> `users`). Also, -any 'Model' suffix will be removed (e.g., `UserModel` -> `users`). To override -this behavior, you can specify the `table_name` in the `Meta` class manually as -above. + class Meta: + table_name = "people" +``` > [!NOTE] > @@ -36,3 +55,27 @@ above. > you need more advanced pluralization, you can install the `extras` package as > mentioned in the [installation](../installation.md#optional-dependencies). Of > course, you can always specify the `table_name` manually in this case! + +## Model Classmethods + +There are 2 useful methods you can call on your models. Note that they are +**Class Methods** so should be called on the Model class itself, not an +instance of the model: + +### `get_table_name()` + +This method returns the actual table name for the model either specified or +automatically generated. This is useful if you need to do any raw SQL queries. + +```python +table_name = User.get_table_name() +``` + +### `get_primary_key()` + +This simply returns the name of the primary key for that table. At the moment, +this will always return the string `pk` but this may change in the future. + +```python +primary_key = User.get_primary_key() +``` diff --git a/docs/index.md b/docs/index.md index bf8bed0..c45e465 100644 --- a/docs/index.md +++ b/docs/index.md @@ -22,7 +22,10 @@ database-like format without needing to learn SQL or use a full ORM. > [!CAUTION] > This project is still in the early stages of development and is lacking some -> planned functionality. Please use with caution. +> planned functionality. Please use with caution - Classes and methods may +> change until a stable release is made. I'll try to keep this to an absolute +> minimum and the releases and documentation will be very clear about any +> breaking changes. > > Also, structures like `list`, `dict`, `set` etc are not supported **at this > time** as field types, since SQLite does not have a native column type for @@ -36,7 +39,7 @@ database-like format without needing to learn SQL or use a full ORM. - Table creation based on Pydantic models - CRUD operations (Create, Read, Update, Delete) -- Basic query building with filtering, ordering, and pagination +- Chained Query building with filtering, ordering, and pagination - Transaction support - Custom exceptions for better error handling - Full type hinting and type checking diff --git a/docs/quickstart.md b/docs/quickstart.md index e5a6ead..21f63b9 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -26,7 +26,7 @@ db.create_table(User) # Insert a record user = User(name="John Doe", age=30) -db.insert(user) +new_user = db.insert(user) # Query records results = db.select(User).filter(name="John Doe").fetch_all() @@ -34,11 +34,15 @@ for user in results: print(f"User: {user.name}, Age: {user.age}, Admin: {user.admin}") # Update a record -user.age = 31 -db.update(user) +new_user.age = 31 +db.update(new_user) + +results = db.select(User).filter(name="John Doe").fetch_one() + +print("Updated age:", results.age) # Delete a record -db.delete(User, "John Doe") +db.delete(User, new_user.pk) ``` See the [Guide](guide/guide.md) for more detailed information on how to use `SQLiter`. diff --git a/mkdocs.yml b/mkdocs.yml index 647e769..c0dd678 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,6 +12,7 @@ theme: - navigation.tabs - navigation.sections - navigation.indexes + - content.code.copy extra: social: diff --git a/pyproject.toml b/pyproject.toml index 047e36f..9c28bad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,9 +144,10 @@ known-first-party = ["sqliter"] keep-runtime-typing = true [tool.mypy] +plugins = ["pydantic.mypy"] + python_version = "3.9" exclude = ["docs"] - [[tool.mypy.overrides]] disable_error_code = ["method-assign", "no-untyped-def", "attr-defined"] module = "tests.*" diff --git a/sqliter/exceptions.py b/sqliter/exceptions.py index cf49652..f248a94 100644 --- a/sqliter/exceptions.py +++ b/sqliter/exceptions.py @@ -114,7 +114,7 @@ class RecordUpdateError(SqliterError): class RecordNotFoundError(SqliterError): """Exception raised when a requested record is not found in the database.""" - message_template = "Failed to find a record for key '{}' " + message_template = "Failed to find that record in the table (key '{}') " class RecordFetchError(SqliterError): diff --git a/sqliter/model/model.py b/sqliter/model/model.py index 1ceca4f..130c3a8 100644 --- a/sqliter/model/model.py +++ b/sqliter/model/model.py @@ -10,9 +10,9 @@ from __future__ import annotations import re -from typing import Any, Optional, TypeVar, Union, get_args, get_origin +from typing import Any, Optional, TypeVar, Union, cast, get_args, get_origin -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field T = TypeVar("T", bound="BaseDBModel") @@ -28,6 +28,8 @@ class BaseDBModel(BaseModel): representing database models. """ + pk: int = Field(0, description="The mandatory primary key of the table.") + model_config = ConfigDict( extra="ignore", populate_by_name=True, @@ -44,10 +46,6 @@ class Meta: table_name (Optional[str]): The name of the database table. """ - create_pk: bool = ( - True # Whether to create an auto-increment primary key - ) - primary_key: str = "id" # Default primary key name table_name: Optional[str] = ( None # Table name, defaults to class name if not set ) @@ -89,7 +87,7 @@ def model_validate_partial(cls: type[T], obj: dict[str, Any]) -> T: else: converted_obj[field_name] = field_type(value) - return cls.model_construct(**converted_obj) + return cast(T, cls.model_construct(**converted_obj)) @classmethod def get_table_name(cls) -> str: @@ -127,18 +125,10 @@ def get_table_name(cls) -> str: @classmethod def get_primary_key(cls) -> str: - """Get the primary key field name for the model. - - Returns: - The name of the primary key field. - """ - return getattr(cls.Meta, "primary_key", "id") + """Returns the mandatory primary key, always 'pk'.""" + return "pk" @classmethod def should_create_pk(cls) -> bool: - """Determine if a primary key should be automatically created. - - Returns: - True if a primary key should be created, False otherwise. - """ - return getattr(cls.Meta, "create_pk", True) + """Returns True since the primary key is always created.""" + return True diff --git a/sqliter/query/query.py b/sqliter/query/query.py index 704c30d..a4ba9de 100644 --- a/sqliter/query/query.py +++ b/sqliter/query/query.py @@ -145,6 +145,8 @@ def fields(self, fields: Optional[list[str]] = None) -> QueryBuilder: The QueryBuilder instance for method chaining. """ if fields: + if "pk" not in fields: + fields.append("pk") self._fields = fields self._validate_fields() return self @@ -164,6 +166,9 @@ def exclude(self, fields: Optional[list[str]] = None) -> QueryBuilder: invalid fields are specified. """ if fields: + if "pk" in fields: + err = "The primary key 'pk' cannot be excluded." + raise ValueError(err) all_fields = set(self.model_class.model_fields.keys()) # Check for invalid fields before subtraction @@ -179,7 +184,7 @@ def exclude(self, fields: Optional[list[str]] = None) -> QueryBuilder: self._fields = list(all_fields - set(fields)) # Explicit check: raise an error if no fields remain - if not self._fields: + if self._fields == ["pk"]: err = "Exclusion results in no fields being selected." raise ValueError(err) @@ -208,7 +213,7 @@ def only(self, field: str) -> QueryBuilder: raise ValueError(err) # Set self._fields to just the single field - self._fields = [field] + self._fields = [field, "pk"] return self def _get_operator_handler( @@ -527,6 +532,8 @@ def _execute_query( if count_only: fields = "COUNT(*)" elif self._fields: + if "pk" not in self._fields: + self._fields.append("pk") fields = ", ".join(f'"{field}"' for field in self._fields) else: fields = ", ".join( diff --git a/sqliter/sqliter.py b/sqliter/sqliter.py index 7e2a392..62bf605 100644 --- a/sqliter/sqliter.py +++ b/sqliter/sqliter.py @@ -10,7 +10,7 @@ import logging import sqlite3 -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, TypeVar from typing_extensions import Self @@ -33,6 +33,8 @@ from sqliter.model.model import BaseDBModel +T = TypeVar("T", bound="BaseDBModel") + class SqliterDB: """Main class for interacting with SQLite databases. @@ -223,28 +225,12 @@ def create_table( """ table_name = model_class.get_table_name() primary_key = model_class.get_primary_key() - create_pk = model_class.should_create_pk() if force: drop_table_sql = f"DROP TABLE IF EXISTS {table_name}" self._execute_sql(drop_table_sql) - fields = [] - - # Always add the primary key field first - if create_pk: - fields.append(f"{primary_key} INTEGER PRIMARY KEY AUTOINCREMENT") - else: - field_info = model_class.model_fields.get(primary_key) - if field_info is not None: - sqlite_type = infer_sqlite_type(field_info.annotation) - fields.append(f"{primary_key} {sqlite_type} PRIMARY KEY") - else: - err = ( - f"Primary key field '{primary_key}' not found in model " - "fields." - ) - raise ValueError(err) + fields = [f'"{primary_key}" INTEGER PRIMARY KEY AUTOINCREMENT'] # Add remaining fields for field_name, field_info in model_class.model_fields.items(): @@ -325,19 +311,28 @@ def _maybe_commit(self) -> None: if self.auto_commit and self.conn: self.conn.commit() - def insert(self, model_instance: BaseDBModel) -> None: + def insert(self, model_instance: T) -> T: """Insert a new record into the database. Args: - model_instance: An instance of a Pydantic model to be inserted. + model_instance: The instance of the model class to insert. + + Returns: + The updated model instance with the primary key (pk) set. Raises: - RecordInsertionError: If there's an error inserting the record. + RecordInsertionError: If an error occurs during the insertion. """ model_class = type(model_instance) table_name = model_class.get_table_name() + # Get the data from the model data = model_instance.model_dump() + # remove the primary key field if it exists, otherwise we'll get + # TypeErrors as multiple primary keys will exist + if data.get("pk", None) == 0: + data.pop("pk") + fields = ", ".join(data.keys()) placeholders = ", ".join( ["?" if value is not None else "NULL" for value in data.values()] @@ -354,11 +349,15 @@ def insert(self, model_instance: BaseDBModel) -> None: cursor = conn.cursor() cursor.execute(insert_sql, values) self._maybe_commit() + except sqlite3.Error as exc: raise RecordInsertionError(table_name) from exc + else: + data.pop("pk", None) + return model_class(pk=cursor.lastrowid, **data) def get( - self, model_class: type[BaseDBModel], primary_key_value: str + self, model_class: type[BaseDBModel], primary_key_value: int ) -> BaseDBModel | None: """Retrieve a single record from the database by its primary key. @@ -405,11 +404,12 @@ def update(self, model_instance: BaseDBModel) -> None: model_instance: An instance of a Pydantic model to be updated. Raises: - RecordUpdateError: If there's an error updating the record. - RecordNotFoundError: If the record to update is not found. + RecordUpdateError: If there's an error updating the record or if it + is not found. """ model_class = type(model_instance) table_name = model_class.get_table_name() + primary_key = model_class.get_primary_key() fields = ", ".join( diff --git a/tests/conftest.py b/tests/conftest.py index 4145c45..2e48dcd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from __future__ import annotations +import os from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union @@ -16,6 +17,12 @@ memory_db = ":memory:" +@pytest.hookimpl(tryfirst=True) +def pytest_configure(config) -> None: + """Clear the screen before running tests.""" + os.system("cls" if os.name == "nt" else "clear") # noqa: S605 + + @contextmanager def not_raises(exception) -> Generator[None, Any, None]: """Fake a pytest.raises context manager that does not raise an exception. @@ -39,8 +46,6 @@ class ExampleModel(BaseDBModel): class Meta: """Configuration for the model.""" - create_pk: bool = False - primary_key: str = "slug" table_name: str = "test_table" @@ -53,9 +58,7 @@ class PersonModel(BaseDBModel): class Meta: """Configuration for the model.""" - create_pk = False table_name = "person_table" - primary_key = "name" class DetailedPersonModel(BaseDBModel): @@ -72,14 +75,11 @@ class Meta: """Configuration for the model.""" table_name = "detailed_person_table" - primary_key = "name" - create_pk = False class ComplexModel(BaseDBModel): """Model to test complex field types.""" - id: int name: str age: float is_active: bool @@ -90,8 +90,6 @@ class Meta: """Configuration for the model.""" table_name = "complex_model" - primary_key = "id" - create_pk = False @pytest.fixture diff --git a/tests/test_debug_logging.py b/tests/test_debug_logging.py index 1176cfe..f4a7eb6 100644 --- a/tests/test_debug_logging.py +++ b/tests/test_debug_logging.py @@ -39,7 +39,7 @@ def test_debug_sql_output_basic_query( # Assert the SQL query was printed assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' '"nullable_field" FROM "complex_model" WHERE age = 30.5' in caplog.text ) @@ -55,7 +55,7 @@ def test_debug_sql_output_string_values( # Assert the SQL query was printed with the string properly quoted assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' '"nullable_field" FROM "complex_model" WHERE name = \'Alice\'' in caplog.text ) @@ -71,7 +71,7 @@ def test_debug_sql_output_multiple_conditions( # Assert the SQL query was printed with multiple conditions assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' '"nullable_field" FROM "complex_model" WHERE name = \'Alice\' AND ' "age = 30.5" in caplog.text ) @@ -87,7 +87,7 @@ def test_debug_sql_output_order_and_limit( # Assert the SQL query was printed with ORDER and LIMIT assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' '"nullable_field" FROM "complex_model" ORDER BY "age" DESC LIMIT 1' in caplog.text ) @@ -99,7 +99,7 @@ def test_debug_sql_output_with_null_value( with caplog.at_level(logging.DEBUG): db_mock_complex_debug.insert( ComplexModel( - id=4, + pk=4, name="David", age=40.0, is_active=True, @@ -114,7 +114,7 @@ def test_debug_sql_output_with_null_value( # Assert the SQL query was printed with IS NULL assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' '"nullable_field" FROM "complex_model" WHERE age IS NULL' in caplog.text ) @@ -130,7 +130,8 @@ def test_debug_sql_output_with_fields_single( # Assert the SQL query only selects the 'name' field assert ( - 'Executing SQL: SELECT "name" FROM "complex_model"' in caplog.text + 'Executing SQL: SELECT "name", "pk" FROM "complex_model"' + in caplog.text ) def test_debug_sql_output_with_fields_multiple( @@ -144,7 +145,7 @@ def test_debug_sql_output_with_fields_multiple( # Assert the SQL query only selects the 'name' and 'age' fields assert ( - 'Executing SQL: SELECT "name", "age" FROM "complex_model"' + 'Executing SQL: SELECT "name", "age", "pk" FROM "complex_model"' in caplog.text ) @@ -159,7 +160,7 @@ def test_debug_sql_output_with_fields_and_filter( # Assert the SQL query selects 'name' and 'score' and applies the filter assert ( - 'Executing SQL: SELECT "name", "score" FROM "complex_model" ' + 'Executing SQL: SELECT "name", "score", "pk" FROM "complex_model" ' "WHERE score > 85" in caplog.text ) @@ -198,7 +199,7 @@ def test_manual_logger_respects_debug_flag(self, caplog) -> None: # Assert that log output was captured with the manually passed logger assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' in caplog.text ) @@ -227,7 +228,7 @@ def test_debug_sql_output_no_matching_records( # Assert that the SQL query was logged despite no matching records assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' '"nullable_field" FROM "complex_model" WHERE age = 100' in caplog.text ) @@ -241,7 +242,7 @@ def test_debug_sql_output_empty_query( # Assert that the SQL query was logged for a full table scan assert ( - 'Executing SQL: SELECT "id", "name", "age", "is_active", "score", ' + 'Executing SQL: SELECT "pk", "name", "age", "is_active", "score", ' '"nullable_field" FROM "complex_model"' in caplog.text ) diff --git a/tests/test_execeptions.py b/tests/test_execeptions.py index cfd4eb5..0bb3009 100644 --- a/tests/test_execeptions.py +++ b/tests/test_execeptions.py @@ -55,6 +55,7 @@ def test_database_connection_error(self, mocker) -> None: exc_info.value ) + # @pytest.mark.skip(reason="This is no longer a valid test case.") def test_insert_duplicate_primary_key(self, db_mock) -> None: """Test that exception raised when inserting duplicate primary key.""" # Create a model instance with a unique primary key @@ -63,11 +64,11 @@ def test_insert_duplicate_primary_key(self, db_mock) -> None: ) # Insert the record for the first time, should succeed - db_mock.insert(example_model) + result = db_mock.insert(example_model) # Try inserting the same record again, which should raise our exception with pytest.raises(RecordInsertionError) as exc_info: - db_mock.insert(example_model) + db_mock.insert(result) # Verify that the exception message contains the table name assert "Failed to insert record into table: 'test_table'" in str( @@ -98,7 +99,9 @@ def test_update_not_found_error(self, db_mock) -> None: db_mock.update(example_model) # Verify that the exception message contains the table name - assert "Failed to find a record for key 'test'" in str(exc_info.value) + assert "Failed to find that record in the table (key '0')" in str( + exc_info.value + ) def test_update_exception_error(self, db_mock, mocker) -> None: """Test an exception is raised when updating a record with an error.""" diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..eec6245 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,49 @@ +"""Specific tests for the Model and it's methods.""" + +from typing import Optional + +import pytest + +from sqliter.model.model import BaseDBModel + + +class TestBaseDBModel: + """Test the Model and it's methods.""" + + def test_should_create_pk(self) -> None: + """Test that 'should_create_pk' returns True.""" + assert BaseDBModel.should_create_pk() is True + + def test_get_primary_key(self) -> None: + """Test that 'get_primary_key' returns 'pk'.""" + assert BaseDBModel.get_primary_key() == "pk" + + def test_get_table_name_default(self) -> None: + """Test that 'get_table_name' returns the default table name.""" + + class TestModel(BaseDBModel): + pass + + assert TestModel.get_table_name() == "tests" + + def test_get_table_name_custom(self) -> None: + """Test that 'get_table_name' returns the custom table name.""" + + class TestModel(BaseDBModel): + class Meta: + table_name = "custom_table" + + assert TestModel.get_table_name() == "custom_table" + + def test_model_validate_partial(self) -> None: + """Test 'model_validate_partial' with partial data.""" + + class TestModel(BaseDBModel): + name: str + age: Optional[int] + + data = {"name": "John"} + model_instance = TestModel.model_validate_partial(data) + assert model_instance.name == "John" + with pytest.raises(AttributeError): + _ = model_instance.age diff --git a/tests/test_optional_fields_complex_model.py b/tests/test_optional_fields_complex_model.py index 3349390..260c0c7 100644 --- a/tests/test_optional_fields_complex_model.py +++ b/tests/test_optional_fields_complex_model.py @@ -8,11 +8,11 @@ @pytest.fixture def db_mock_complex(db_mock: SqliterDB) -> SqliterDB: - """Ficture for a mock database with a complex model.""" + """Fixture for a mock database with a complex model.""" db_mock.create_table(ComplexModel) db_mock.insert( ComplexModel( - id=1, + pk=1, name="Alice", age=30.5, is_active=True, @@ -22,7 +22,7 @@ def db_mock_complex(db_mock: SqliterDB) -> SqliterDB: ) db_mock.insert( ComplexModel( - id=2, + pk=2, name="Bob", age=25.0, is_active=False, @@ -41,7 +41,7 @@ def test_select_all_fields(self, db_mock_complex: SqliterDB) -> None: results = db_mock_complex.select(ComplexModel).fetch_all() assert len(results) == 2 for result in results: - assert isinstance(result.id, int) + assert isinstance(result.pk, int) assert isinstance(result.name, str) assert isinstance(result.age, float) assert isinstance(result.is_active, bool) @@ -53,13 +53,13 @@ def test_select_all_fields(self, db_mock_complex: SqliterDB) -> None: def test_select_subset_of_fields(self, db_mock_complex: SqliterDB) -> None: """Select a subset of fields and ensure their types are correct.""" - fields = ["id", "name", "age", "is_active", "score"] + fields = ["pk", "name", "age", "is_active", "score"] results = db_mock_complex.select( ComplexModel, fields=fields ).fetch_all() assert len(results) == 2 for result in results: - assert isinstance(result.id, int) + assert isinstance(result.pk, int) assert isinstance(result.name, str) assert isinstance(result.age, float) assert isinstance(result.is_active, bool) @@ -71,13 +71,13 @@ def test_select_with_type_conversion( self, db_mock_complex: SqliterDB ) -> None: """Select a subset of fields and ensure their types are correct.""" - fields = ["id", "age", "is_active", "score"] + fields = ["pk", "age", "is_active", "score"] results = db_mock_complex.select( ComplexModel, fields=fields ).fetch_all() assert len(results) == 2 for result in results: - assert isinstance(result.id, int) + assert isinstance(result.pk, int) assert isinstance(result.age, float) assert isinstance(result.is_active, bool) assert isinstance(result.score, (int, float)) @@ -107,7 +107,7 @@ def test_select_with_union_field(self, db_mock_complex: SqliterDB) -> None: def test_select_with_filtering(self, db_mock_complex: SqliterDB) -> None: """Select fields with a filter.""" - fields = ["id", "name", "age"] + fields = ["pk", "name", "age"] results = ( db_mock_complex.select(ComplexModel, fields=fields) .filter(age__gt=28) @@ -119,7 +119,7 @@ def test_select_with_filtering(self, db_mock_complex: SqliterDB) -> None: def test_select_with_ordering(self, db_mock_complex: SqliterDB) -> None: """Select fields with ordering.""" - fields = ["id", "name", "age"] + fields = ["pk", "name", "age"] results = ( db_mock_complex.select(ComplexModel, fields=fields) .order("age", direction="DESC") diff --git a/tests/test_query.py b/tests/test_query.py index 756c993..84d67c2 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -565,8 +565,8 @@ def test_fetch_result_with_list_of_tuples(self, mocker) -> None: # Create some mock tuples (mimicking database rows) mock_result = [ - ("john", "John", "content"), - ("jane", "Jane", "content"), + ("1", "john", "John", "content"), + ("2", "jane", "Jane", "content"), ] # Mock the _execute_query method on the QueryBuilder instance @@ -581,5 +581,13 @@ def test_fetch_result_with_list_of_tuples(self, mocker) -> None: assert not isinstance(result, list) assert isinstance(result, ExampleModel) assert result == ExampleModel( - slug="john", name="John", content="content" + pk=1, slug="john", name="John", content="content" ) + + def test_exclude_pk_raises_valueerror(self) -> None: + """Test that excluding the primary key raises a ValueError.""" + match_str = "The primary key 'pk' cannot be excluded." + + db = SqliterDB(memory=True) + with pytest.raises(ValueError, match=match_str): + db.select(ExampleModel).exclude(["pk"]) diff --git a/tests/test_sqliter.py b/tests/test_sqliter.py index 5a55e35..5447952 100644 --- a/tests/test_sqliter.py +++ b/tests/test_sqliter.py @@ -51,10 +51,10 @@ def test_data_lost_when_auto_commit_disabled(self) -> None: test_model = ExampleModel( slug="test", name="Test License", content="Test Content" ) - db.insert(test_model) + result = db.insert(test_model) # Ensure the record exists - fetched_license = db.get(ExampleModel, "test") + fetched_license = db.get(ExampleModel, result.pk) assert fetched_license is not None # Close the connection @@ -65,7 +65,7 @@ def test_data_lost_when_auto_commit_disabled(self) -> None: # Ensure the data is lost with pytest.raises(RecordFetchError): - db.get(ExampleModel, "test") + db.get(ExampleModel, result.pk) def test_create_table(self, db_mock) -> None: """Test table creation.""" @@ -73,8 +73,8 @@ def test_create_table(self, db_mock) -> None: cursor = conn.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") tables = cursor.fetchall() - assert len(tables) == 1 - assert tables[0][0] == "test_table" + assert len(tables) == 2 + assert ("test_table",) in tables def test_close_connection(self, db_mock) -> None: """Test closing the connection.""" @@ -95,15 +95,13 @@ def test_commit_changes(self, mocker) -> None: assert mock_conn.commit.called - def test_create_table_with_auto_increment(self, db_mock) -> None: + def test_create_table_with_default_auto_increment(self, db_mock) -> None: """Test table creation with auto-incrementing primary key.""" class AutoIncrementModel(BaseDBModel): name: str class Meta: - create_pk: bool = True # Enable auto-increment ID - primary_key: str = "id" # Default primary key is 'id' table_name: str = "auto_increment_table" # Create the table @@ -116,113 +114,10 @@ class Meta: table_info = cursor.fetchall() # Check that the first column is 'id' and it's an auto-incrementing int - assert table_info[0][1] == "id" # Column name + assert table_info[0][1] == "pk" # Column name assert table_info[0][2] == "INTEGER" # Column type assert table_info[0][5] == 1 # Primary key flag - def test_create_table_with_custom_primary_key(self, db_mock) -> None: - """Test table creation with a custom primary key.""" - - class CustomPKModel(BaseDBModel): - code: str - description: str - - class Meta: - create_pk: bool = False # Disable auto-increment ID - primary_key: str = "code" # Use 'code' as the primary key - table_name: str = "custom_pk_table" - - # Create the table - db_mock.create_table(CustomPKModel) - - # Verify that the table was created with 'code' as the primary key - with db_mock.connect() as conn: - cursor = conn.cursor() - cursor.execute("PRAGMA table_info(custom_pk_table);") - table_info = cursor.fetchall() - - # Check that the primary key is the 'code' column - primary_key_column = next(col for col in table_info if col[1] == "code") - assert primary_key_column[1] == "code" # Column name - assert primary_key_column[5] == 1 # Primary key flag - - def test_create_table_with_custom_auto_increment_pk(self, db_mock) -> None: - """Test table creation with a custom auto-incrementing primary key.""" - - class CustomAutoIncrementPKModel(BaseDBModel): - name: str - - class Meta: - create_pk: bool = True # Enable auto-increment ID - primary_key: str = ( - "custom_id" # Use 'custom_id' as the primary key - ) - table_name: str = "custom_auto_increment_pk_table" - - # Create the table - db_mock.create_table(CustomAutoIncrementPKModel) - - # Check the table schema using PRAGMA - with db_mock.connect() as conn: - cursor = conn.cursor() - cursor.execute("PRAGMA table_info(custom_auto_increment_pk_table);") - table_info = cursor.fetchall() - - # Check that the 'custom_id' column is INTEGER and a primary key - primary_key_column = next( - col for col in table_info if col[1] == "custom_id" - ) - assert primary_key_column[1] == "custom_id" # Column name - assert primary_key_column[2] == "INTEGER" # Column type - assert primary_key_column[5] == 1 # Primary key flag - - # Insert rows to verify that the custom primary key auto-increments - model_instance1 = CustomAutoIncrementPKModel(name="First Entry") - model_instance2 = CustomAutoIncrementPKModel(name="Second Entry") - - db_mock.insert(model_instance1) - db_mock.insert(model_instance2) - - # Fetch the inserted rows and check the 'custom_id' values - with db_mock.connect() as conn: - cursor = conn.cursor() - cursor.execute( - "SELECT custom_id, name FROM custom_auto_increment_pk_table;" - ) - results = cursor.fetchall() - - # Check that the custom_id column auto-incremented - assert results[0][0] == 1 - assert results[1][0] == 2 - assert results[0][1] == "First Entry" - assert results[1][1] == "Second Entry" - - def test_create_table_missing_primary_key(self) -> None: - """Test create_table raises ValueError when primary key is missing.""" - - # Define a model that doesn't have the expected primary key - class NoPKModel(BaseDBModel): - # Intentionally omitting the primary key field, e.g., 'id' or 'slug' - name: str - age: int - - class Meta: - create_pk = False - - # Initialize your SqliterDB instance (adjust if needed) - db = SqliterDB(memory=True) # Assuming memory=True uses an in-memory DB - - # Use pytest.raises to check if ValueError is raised - with pytest.raises( - ValueError, - match="Primary key field 'id' not found in model fields.", - ) as exc_info: - db.create_table(NoPKModel) - - # Check that the error message matches the expected output - assert "Primary key field" in str(exc_info.value) - assert "not found in model fields" in str(exc_info.value) - def test_default_table_name(self, db_mock) -> None: """Test the default table name generation. @@ -282,18 +177,19 @@ def test_insert_license(self, db_mock) -> None: cursor = conn.cursor() cursor.execute("SELECT * FROM test_table WHERE slug = ?", ("mit",)) result = cursor.fetchone() - assert result[0] == "mit" - assert result[1] == "MIT License" - assert result[2] == "MIT License Content" + assert result[0] == 1 + assert result[1] == "mit" + assert result[2] == "MIT License" + assert result[3] == "MIT License Content" def test_fetch_license(self, db_mock) -> None: """Test fetching a license by primary key.""" test_model = ExampleModel( slug="gpl", name="GPL License", content="GPL License Content" ) - db_mock.insert(test_model) + result = db_mock.insert(test_model) - fetched_license = db_mock.get(ExampleModel, "gpl") + fetched_license = db_mock.get(ExampleModel, result.pk) assert fetched_license is not None assert fetched_license.slug == "gpl" assert fetched_license.name == "GPL License" @@ -304,14 +200,14 @@ def test_update(self, db_mock) -> None: test_model = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" ) - db_mock.insert(test_model) + result = db_mock.insert(test_model) # Update license content - test_model.content = "Updated MIT License Content" - db_mock.update(test_model) + result.content = "Updated MIT License Content" + db_mock.update(result) # Fetch and check if updated - fetched_license = db_mock.get(ExampleModel, "mit") + fetched_license = db_mock.get(ExampleModel, result.pk) assert fetched_license.content == "Updated MIT License Content" def test_delete(self, db_mock) -> None: @@ -319,13 +215,13 @@ def test_delete(self, db_mock) -> None: test_model = ExampleModel( slug="mit", name="MIT License", content="MIT License Content" ) - db_mock.insert(test_model) + result = db_mock.insert(test_model) # Delete the record - db_mock.delete(ExampleModel, "mit") + db_mock.delete(ExampleModel, result.pk) # Ensure it no longer exists - fetched_license = db_mock.get(ExampleModel, "mit") + fetched_license = db_mock.get(ExampleModel, result.pk) assert fetched_license is None def test_select_filter(self, db_mock) -> None: @@ -455,14 +351,14 @@ def test_update_existing_record(self, db_mock) -> None: example_model = ExampleModel( slug="test", name="Test License", content="Test Content" ) - db_mock.insert(example_model) + result = db_mock.insert(example_model) # Update the record's content - example_model.content = "Updated Content" - db_mock.update(example_model) + result.content = "Updated Content" + db_mock.update(result) # Fetch the updated record and verify the changes - updated_record = db_mock.get(ExampleModel, "test") + updated_record = db_mock.get(ExampleModel, result.pk) assert updated_record is not None assert updated_record.content == "Updated Content" @@ -480,7 +376,7 @@ def test_update_non_existing_record(self, db_mock) -> None: db_mock.update(example_model) # Check that the correct error message is raised - assert "Failed to find a record for key 'nonexistent'" in str( + assert "Failed to find that record in the table (key '0')" in str( exc_info.value ) @@ -510,13 +406,13 @@ def test_delete_existing_record(self, db_mock) -> None: test_model = ExampleModel( slug="test", name="Test License", content="Test Content" ) - db_mock.insert(test_model) + result = db_mock.insert(test_model) # Now delete the record - db_mock.delete(ExampleModel, "test") + db_mock.delete(ExampleModel, result.pk) # Fetch the deleted record to confirm it's gone - result = db_mock.get(ExampleModel, "test") + result = db_mock.get(ExampleModel, result.pk) assert result is None def test_transaction_commit_success(self, db_mock, mocker) -> None: @@ -685,7 +581,7 @@ def test_complex_model_field_types(self, db_mock) -> None: # Expected types in SQLite (INTEGER, REAL, TEXT, etc.) expected_types = { - "id": "INTEGER", + "pk": "INTEGER", "name": "TEXT", "age": "REAL", "price": "REAL", @@ -727,7 +623,7 @@ def test_complex_model_primary_key(self, db_mock) -> None: # Assert that the primary key is the 'id' field and is an INTEGER assert primary_key_column is not None, "Primary key not found" assert ( - primary_key_column[1] == "id" + primary_key_column[1] == "pk" ), f"Expected 'id' as primary key, but got {primary_key_column[1]}" assert primary_key_column[2] == "INTEGER", ( f"Expected 'INTEGER' type for primary key, but got "