Skip to content

Commit

Permalink
implement user-defined indexes
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Ramsay <seapagan@gmail.com>
  • Loading branch information
seapagan committed Oct 12, 2024
1 parent cf8797f commit 6f6dff2
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
27 changes: 23 additions & 4 deletions sqliter/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@
from __future__ import annotations

import re
from typing import Any, Optional, TypeVar, Union, cast, get_args, get_origin
from typing import (
Any,
ClassVar,
Optional,
TypeVar,
Union,
cast,
get_args,
get_origin,
)

from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -41,14 +50,24 @@ class Meta:
"""Metadata class for configuring database-specific attributes.
Attributes:
create_pk (bool): Whether to create a primary key field.
primary_key (str): The name of the primary key field.
table_name (Optional[str]): The name of the database table.
table_name (Optional[str]): The name of the database table. If not
specified, the table name will be inferred from the model class
name and converted to snake_case.
indexes (ClassVar[list[Union[str, tuple[str]]]]): A list of fields
or tuples of fields for which regular (non-unique) indexes
should be created. Indexes improve query performance on these
fields.
unique_indexes (ClassVar[list[Union[str, tuple[str]]]]): A list of
fields or tuples of fields for which unique indexes should be
created. Unique indexes enforce that all values in these fields
are distinct across the table.
"""

table_name: Optional[str] = (
None # Table name, defaults to class name if not set
)
indexes: ClassVar[list[Union[str, tuple[str]]]] = []
unique_indexes: ClassVar[list[Union[str, tuple[str]]]] = []

@classmethod
def model_validate_partial(cls: type[T], obj: dict[str, Any]) -> T:
Expand Down
52 changes: 51 additions & 1 deletion sqliter/sqliter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import logging
import sqlite3
from typing import TYPE_CHECKING, Any, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

from typing_extensions import Self

Expand Down Expand Up @@ -261,6 +261,56 @@ def create_table(
except sqlite3.Error as exc:
raise TableCreationError(table_name) from exc

# Create regular indexes
if hasattr(model_class.Meta, "indexes"):
self._create_indexes(
model_class, model_class.Meta.indexes, unique=False
)

# Create unique indexes
if hasattr(model_class.Meta, "unique_indexes"):
self._create_indexes(
model_class, model_class.Meta.unique_indexes, unique=True
)

def _create_indexes(
self,
model_class: type[BaseDBModel],
indexes: list[Union[str, tuple[str]]],
*,
unique: bool = False,
) -> None:
"""Helper method to create regular or unique indexes.
Args:
model_class: The model class defining the table.
indexes: List of fields or tuples of fields to create indexes for.
unique: If True, creates UNIQUE indexes; otherwise, creates regular
indexes.
"""
for index in indexes:
# Handle multiple fields in tuple form
if isinstance(index, tuple):
index_name = "_".join(index)
fields = list(index) # Ensure fields is a list of strings
else:
index_name = index
fields = [index] # Wrap single field in a list

# Add '_unique' postfix to index name for unique indexes
index_postfix = "_unique" if unique else ""
index_type = (
"UNIQUE" if unique else ""
) # Add UNIQUE for unique indexes

create_index_sql = (
f"CREATE {index_type} INDEX IF NOT EXISTS "
f"idx_{model_class.get_table_name()}"
f"_{index_name}{index_postfix} "
f"ON {model_class.get_table_name()} ({', '.join(fields)})"
)
self._execute_sql(create_index_sql)

def _execute_sql(self, sql: str) -> None:
"""Execute an SQL statement.
Expand Down

0 comments on commit 6f6dff2

Please sign in to comment.