-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add a 'Unique' method and basic test
Signed-off-by: Grant Ramsay <seapagan@gmail.com>
- Loading branch information
Showing
4 changed files
with
85 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
"""This module provides the base model class for SQLiter database models. | ||
It exports the BaseDBModel class, which is used to define database | ||
models in SQLiter applications. | ||
models in SQLiter applications, and the Unique class, which is used to | ||
define unique constraints on model fields. | ||
""" | ||
|
||
from .model import BaseDBModel | ||
from .unique import Unique | ||
|
||
__all__ = ["BaseDBModel"] | ||
__all__ = ["BaseDBModel", "Unique"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""Define a custom field type for unique constraints in SQLiter.""" | ||
|
||
from typing import Any | ||
|
||
from pydantic.fields import FieldInfo | ||
|
||
|
||
class Unique(FieldInfo): | ||
"""A custom field type for unique constraints in SQLiter.""" | ||
|
||
def __init__(self, default: Any = ..., **kwargs: Any) -> None: # noqa: ANN401 | ||
"""Initialize a Unique field. | ||
Args: | ||
default: The default value for the field. | ||
**kwargs: Additional keyword arguments to pass to FieldInfo. | ||
""" | ||
super().__init__(default=default, **kwargs) | ||
self.unique = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
"""Test the Unique constraint.""" | ||
|
||
from typing import Annotated | ||
|
||
import pytest | ||
|
||
from sqliter import SqliterDB | ||
from sqliter.exceptions import SqliterError | ||
from sqliter.model import BaseDBModel | ||
from sqliter.model.unique import Unique | ||
|
||
|
||
class TestUnique: | ||
"""Test suite for the Unique constraint.""" | ||
|
||
def test_unique_constraint(self) -> None: | ||
"""Test that the Unique constraint is properly applied.""" | ||
|
||
class User(BaseDBModel): | ||
name: str | ||
email: Annotated[str, Unique()] | ||
|
||
db = SqliterDB(":memory:") | ||
db.create_table(User) | ||
|
||
# Insert a user successfully | ||
user1 = User(name="Alice", email="alice@example.com") | ||
db.insert(user1) | ||
|
||
# Attempt to insert a user with the same email | ||
user2 = User(name="Bob", email="alice@example.com") | ||
|
||
with pytest.raises(SqliterError) as excinfo: | ||
db.insert(user2) | ||
|
||
assert "UNIQUE constraint failed: users.email" in str(excinfo.value) | ||
|
||
# Verify that only one user was inserted | ||
users = db.select(User).fetch_all() | ||
assert len(users) == 1 | ||
assert users[0].name == "Alice" | ||
assert users[0].email == "alice@example.com" | ||
|
||
# Insert a user with a different email successfully | ||
user3 = User(name="Charlie", email="charlie@example.com") | ||
db.insert(user3) | ||
|
||
# Verify that two users are now in the database | ||
users = db.select(User).fetch_all() | ||
assert len(users) == 2 | ||
assert {u.name for u in users} == {"Alice", "Charlie"} | ||
assert {u.email for u in users} == { | ||
"alice@example.com", | ||
"charlie@example.com", | ||
} |