diff --git a/sqliter/model/model.py b/sqliter/model/model.py index 130c3a8..de67b85 100644 --- a/sqliter/model/model.py +++ b/sqliter/model/model.py @@ -10,7 +10,16 @@ from __future__ import annotations import re -from typing import Any, Optional, TypeVar, Union, cast, get_args, get_origin +from typing import ( + Any, + ClassVar, + Optional, + TypeVar, + Union, + cast, + get_args, + get_origin, +) from pydantic import BaseModel, ConfigDict, Field @@ -41,14 +50,24 @@ class Meta: """Metadata class for configuring database-specific attributes. Attributes: - create_pk (bool): Whether to create a primary key field. - primary_key (str): The name of the primary key field. - table_name (Optional[str]): The name of the database table. + table_name (Optional[str]): The name of the database table. If not + specified, the table name will be inferred from the model class + name and converted to snake_case. + indexes (ClassVar[list[Union[str, tuple[str]]]]): A list of fields + or tuples of fields for which regular (non-unique) indexes + should be created. Indexes improve query performance on these + fields. + unique_indexes (ClassVar[list[Union[str, tuple[str]]]]): A list of + fields or tuples of fields for which unique indexes should be + created. Unique indexes enforce that all values in these fields + are distinct across the table. """ table_name: Optional[str] = ( None # Table name, defaults to class name if not set ) + indexes: ClassVar[list[Union[str, tuple[str]]]] = [] + unique_indexes: ClassVar[list[Union[str, tuple[str]]]] = [] @classmethod def model_validate_partial(cls: type[T], obj: dict[str, Any]) -> T: diff --git a/sqliter/sqliter.py b/sqliter/sqliter.py index 5ce7817..2a3d2d5 100644 --- a/sqliter/sqliter.py +++ b/sqliter/sqliter.py @@ -10,7 +10,7 @@ import logging import sqlite3 -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from typing_extensions import Self @@ -261,6 +261,56 @@ def create_table( except sqlite3.Error as exc: raise TableCreationError(table_name) from exc + # Create regular indexes + if hasattr(model_class.Meta, "indexes"): + self._create_indexes( + model_class, model_class.Meta.indexes, unique=False + ) + + # Create unique indexes + if hasattr(model_class.Meta, "unique_indexes"): + self._create_indexes( + model_class, model_class.Meta.unique_indexes, unique=True + ) + + def _create_indexes( + self, + model_class: type[BaseDBModel], + indexes: list[Union[str, tuple[str]]], + *, + unique: bool = False, + ) -> None: + """Helper method to create regular or unique indexes. + + Args: + model_class: The model class defining the table. + indexes: List of fields or tuples of fields to create indexes for. + unique: If True, creates UNIQUE indexes; otherwise, creates regular + indexes. + """ + for index in indexes: + # Handle multiple fields in tuple form + if isinstance(index, tuple): + index_name = "_".join(index) + fields = list(index) # Ensure fields is a list of strings + else: + index_name = index + fields = [index] # Wrap single field in a list + + # Add '_unique' postfix to index name for unique indexes + index_postfix = "_unique" if unique else "" + index_type = ( + "UNIQUE" if unique else "" + ) # Add UNIQUE for unique indexes + + create_index_sql = ( + f"CREATE {index_type} INDEX IF NOT EXISTS " + f"idx_{model_class.get_table_name()}" + f"_{index_name}{index_postfix} " + f"ON {model_class.get_table_name()} ({', '.join(fields)})" + ) + self._execute_sql(create_index_sql) + def _execute_sql(self, sql: str) -> None: """Execute an SQL statement.