Skip to content

Commit dcca4fe

Browse files
committed
add optional 'fields' arg to 'Select()'
Signed-off-by: Grant Ramsay <seapagan@gmail.com>
1 parent 48e07b2 commit dcca4fe

File tree

3 files changed

+87
-14
lines changed

3 files changed

+87
-14
lines changed

sqliter/model/model.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,23 @@
22

33
from __future__ import annotations
44

5-
from typing import Optional
5+
from typing import Any, Optional, TypeVar
66

7-
from pydantic import BaseModel
7+
from pydantic import BaseModel, ConfigDict
8+
9+
T = TypeVar("T", bound="BaseDBModel")
810

911

1012
class BaseDBModel(BaseModel):
1113
"""Custom base model for database models."""
1214

15+
model_config = ConfigDict(
16+
extra="ignore",
17+
populate_by_name=True,
18+
validate_assignment=False,
19+
from_attributes=True,
20+
)
21+
1322
class Meta:
1423
"""Configure the base model with default options."""
1524

@@ -19,6 +28,15 @@ class Meta:
1928
None # Table name, defaults to class name if not set
2029
)
2130

31+
@classmethod
32+
def model_validate_partial(cls: type[T], obj: dict[str, Any]) -> T:
33+
"""Validate a partial model object.
34+
35+
This would be in the case that we are only returning a subset of the
36+
fields.
37+
"""
38+
return cls.model_validate(obj, strict=False)
39+
2240
@classmethod
2341
def get_table_name(cls) -> str:
2442
"""Get the table name from the Meta, or default to the classname."""

sqliter/query/query.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,45 @@
3030
class QueryBuilder:
3131
"""Functions to build and execute queries for a given model."""
3232

33-
def __init__(self, db: SqliterDB, model_class: type[BaseDBModel]) -> None:
34-
"""Initialize the query builder with the database, model class, etc."""
33+
def __init__(
34+
self,
35+
db: SqliterDB,
36+
model_class: type[BaseDBModel],
37+
fields: Optional[list[str]] = None,
38+
) -> None:
39+
"""Initialize the query builder.
40+
41+
Pass the database, model class, and optional fields.
42+
43+
Args:
44+
db: The SqliterDB instance.
45+
model_class: The model class to query.
46+
fields: Optional list of field names to select. If None, all fields
47+
are selected.
48+
"""
3549
self.db = db
3650
self.model_class = model_class
3751
self.table_name = model_class.get_table_name() # Use model_class method
3852
self.filters: list[tuple[str, Any, str]] = []
3953
self._limit: Optional[int] = None
4054
self._offset: Optional[int] = None
4155
self._order_by: Optional[str] = None
56+
self._fields: Optional[list[str]] = fields
57+
58+
if self._fields:
59+
self._validate_fields()
60+
61+
def _validate_fields(self) -> None:
62+
"""Validate that the specified fields exist in the model."""
63+
if self._fields is None:
64+
return
65+
valid_fields = set(self.model_class.model_fields.keys())
66+
invalid_fields = set(self._fields) - valid_fields
67+
if invalid_fields:
68+
err_message = (
69+
f"Invalid fields specified: {', '.join(invalid_fields)}"
70+
)
71+
raise ValueError(err_message)
4272

4373
def filter(self, **conditions: str | float | None) -> QueryBuilder:
4474
"""Add filter conditions to the query."""
@@ -219,15 +249,20 @@ def _execute_query(
219249
count_only: bool = False,
220250
) -> list[tuple[Any, ...]] | Optional[tuple[Any, ...]]:
221251
"""Helper function to execute the query with filters."""
222-
fields = ", ".join(self.model_class.model_fields)
252+
if count_only:
253+
fields = "COUNT(*)"
254+
elif self._fields:
255+
fields = ", ".join(f'"{field}"' for field in self._fields)
256+
else:
257+
fields = ", ".join(
258+
f'"{field}"' for field in self.model_class.model_fields
259+
)
260+
261+
sql = f'SELECT {fields} FROM "{self.table_name}"' # noqa: S608 # nosec
223262

224263
# Build the WHERE clause with special handling for None (NULL in SQL)
225264
values, where_clause = self._parse_filter()
226265

227-
select_fields = fields if not count_only else "COUNT(*)"
228-
229-
sql = f'SELECT {select_fields} FROM "{self.table_name}"' # noqa: S608 # nosec
230-
231266
if self.filters:
232267
sql += f" WHERE {where_clause}"
233268

@@ -276,9 +311,17 @@ def fetch_all(self) -> list[BaseDBModel]:
276311
if not results:
277312
return []
278313

314+
if self._fields:
315+
return [
316+
self.model_class.model_validate_partial(
317+
{field: row[idx] for idx, field in enumerate(self._fields)}
318+
)
319+
for row in results
320+
]
321+
279322
return [
280-
self.model_class(
281-
**{
323+
self.model_class.model_validate(
324+
{
282325
field: row[idx]
283326
for idx, field in enumerate(self.model_class.model_fields)
284327
}

sqliter/sqliter.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,21 @@ def delete(
206206
except sqlite3.Error as exc:
207207
raise RecordDeletionError(table_name) from exc
208208

209-
def select(self, model_class: type[BaseDBModel]) -> QueryBuilder:
210-
"""Start a query for the given model."""
211-
return QueryBuilder(self, model_class)
209+
def select(
210+
self, model_class: type[BaseDBModel], fields: Optional[list[str]] = None
211+
) -> QueryBuilder:
212+
"""Start a query for the given model.
213+
214+
Args:
215+
model_class: The model class to query.
216+
fields: Optional list of field names to select.
217+
If None, all fields are selected.
218+
219+
Returns:
220+
QueryBuilder: An instance of QueryBuilder for the given model and
221+
fields.
222+
"""
223+
return QueryBuilder(self, model_class, fields)
212224

213225
# --- Context manager methods ---
214226
def __enter__(self) -> Self:

0 commit comments

Comments
 (0)