diff --git a/sqliter/model/__init__.py b/sqliter/model/__init__.py index 0c1a48d..463b622 100644 --- a/sqliter/model/__init__.py +++ b/sqliter/model/__init__.py @@ -1,9 +1,11 @@ """This module provides the base model class for SQLiter database models. It exports the BaseDBModel class, which is used to define database -models in SQLiter applications. +models in SQLiter applications, and the Unique class, which is used to +define unique constraints on model fields. """ from .model import BaseDBModel +from .unique import Unique -__all__ = ["BaseDBModel"] +__all__ = ["BaseDBModel", "Unique"] diff --git a/sqliter/model/unique.py b/sqliter/model/unique.py new file mode 100644 index 0000000..c02651a --- /dev/null +++ b/sqliter/model/unique.py @@ -0,0 +1,19 @@ +"""Define a custom field type for unique constraints in SQLiter.""" + +from typing import Any + +from pydantic.fields import FieldInfo + + +class Unique(FieldInfo): + """A custom field type for unique constraints in SQLiter.""" + + def __init__(self, default: Any = ..., **kwargs: Any) -> None: # noqa: ANN401 + """Initialize a Unique field. + + Args: + default: The default value for the field. + **kwargs: Additional keyword arguments to pass to FieldInfo. + """ + super().__init__(default=default, **kwargs) + self.unique = True diff --git a/sqliter/sqliter.py b/sqliter/sqliter.py index 18f43b9..b5b7632 100644 --- a/sqliter/sqliter.py +++ b/sqliter/sqliter.py @@ -27,6 +27,7 @@ TableDeletionError, ) from sqliter.helpers import infer_sqlite_type +from sqliter.model.unique import Unique from sqliter.query.query import QueryBuilder if TYPE_CHECKING: # pragma: no cover @@ -239,7 +240,12 @@ def create_table( for field_name, field_info in model_class.model_fields.items(): if field_name != primary_key: sqlite_type = infer_sqlite_type(field_info.annotation) - fields.append(f"{field_name} {sqlite_type}") + unique_constraint = ( + "UNIQUE" if isinstance(field_info, Unique) else "" + ) + fields.append( + f"{field_name} {sqlite_type} {unique_constraint}".strip() + ) create_str = ( "CREATE TABLE IF NOT EXISTS" if exists_ok else "CREATE TABLE" diff --git a/tests/test_unique.py b/tests/test_unique.py new file mode 100644 index 0000000..977549c --- /dev/null +++ b/tests/test_unique.py @@ -0,0 +1,55 @@ +"""Test the Unique constraint.""" + +from typing import Annotated + +import pytest + +from sqliter import SqliterDB +from sqliter.exceptions import SqliterError +from sqliter.model import BaseDBModel +from sqliter.model.unique import Unique + + +class TestUnique: + """Test suite for the Unique constraint.""" + + def test_unique_constraint(self) -> None: + """Test that the Unique constraint is properly applied.""" + + class User(BaseDBModel): + name: str + email: Annotated[str, Unique()] + + db = SqliterDB(":memory:") + db.create_table(User) + + # Insert a user successfully + user1 = User(name="Alice", email="alice@example.com") + db.insert(user1) + + # Attempt to insert a user with the same email + user2 = User(name="Bob", email="alice@example.com") + + with pytest.raises(SqliterError) as excinfo: + db.insert(user2) + + assert "UNIQUE constraint failed: users.email" in str(excinfo.value) + + # Verify that only one user was inserted + users = db.select(User).fetch_all() + assert len(users) == 1 + assert users[0].name == "Alice" + assert users[0].email == "alice@example.com" + + # Insert a user with a different email successfully + user3 = User(name="Charlie", email="charlie@example.com") + db.insert(user3) + + # Verify that two users are now in the database + users = db.select(User).fetch_all() + assert len(users) == 2 + assert {u.name for u in users} == {"Alice", "Charlie"} + assert {u.email for u in users} == { + "alice@example.com", + "charlie@example.com", + }