|
30 | 30 | class QueryBuilder:
|
31 | 31 | """Functions to build and execute queries for a given model."""
|
32 | 32 |
|
33 |
| - def __init__(self, db: SqliterDB, model_class: type[BaseDBModel]) -> None: |
34 |
| - """Initialize the query builder with the database, model class, etc.""" |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + db: SqliterDB, |
| 36 | + model_class: type[BaseDBModel], |
| 37 | + fields: Optional[list[str]] = None, |
| 38 | + ) -> None: |
| 39 | + """Initialize the query builder. |
| 40 | +
|
| 41 | + Pass the database, model class, and optional fields. |
| 42 | +
|
| 43 | + Args: |
| 44 | + db: The SqliterDB instance. |
| 45 | + model_class: The model class to query. |
| 46 | + fields: Optional list of field names to select. If None, all fields |
| 47 | + are selected. |
| 48 | + """ |
35 | 49 | self.db = db
|
36 | 50 | self.model_class = model_class
|
37 | 51 | self.table_name = model_class.get_table_name() # Use model_class method
|
38 | 52 | self.filters: list[tuple[str, Any, str]] = []
|
39 | 53 | self._limit: Optional[int] = None
|
40 | 54 | self._offset: Optional[int] = None
|
41 | 55 | self._order_by: Optional[str] = None
|
| 56 | + self._fields: Optional[list[str]] = fields |
| 57 | + |
| 58 | + if self._fields: |
| 59 | + self._validate_fields() |
| 60 | + |
| 61 | + def _validate_fields(self) -> None: |
| 62 | + """Validate that the specified fields exist in the model.""" |
| 63 | + if self._fields is None: |
| 64 | + return |
| 65 | + valid_fields = set(self.model_class.model_fields.keys()) |
| 66 | + invalid_fields = set(self._fields) - valid_fields |
| 67 | + if invalid_fields: |
| 68 | + err_message = ( |
| 69 | + f"Invalid fields specified: {', '.join(invalid_fields)}" |
| 70 | + ) |
| 71 | + raise ValueError(err_message) |
42 | 72 |
|
43 | 73 | def filter(self, **conditions: str | float | None) -> QueryBuilder:
|
44 | 74 | """Add filter conditions to the query."""
|
@@ -219,15 +249,20 @@ def _execute_query(
|
219 | 249 | count_only: bool = False,
|
220 | 250 | ) -> list[tuple[Any, ...]] | Optional[tuple[Any, ...]]:
|
221 | 251 | """Helper function to execute the query with filters."""
|
222 |
| - fields = ", ".join(self.model_class.model_fields) |
| 252 | + if count_only: |
| 253 | + fields = "COUNT(*)" |
| 254 | + elif self._fields: |
| 255 | + fields = ", ".join(f'"{field}"' for field in self._fields) |
| 256 | + else: |
| 257 | + fields = ", ".join( |
| 258 | + f'"{field}"' for field in self.model_class.model_fields |
| 259 | + ) |
| 260 | + |
| 261 | + sql = f'SELECT {fields} FROM "{self.table_name}"' # noqa: S608 # nosec |
223 | 262 |
|
224 | 263 | # Build the WHERE clause with special handling for None (NULL in SQL)
|
225 | 264 | values, where_clause = self._parse_filter()
|
226 | 265 |
|
227 |
| - select_fields = fields if not count_only else "COUNT(*)" |
228 |
| - |
229 |
| - sql = f'SELECT {select_fields} FROM "{self.table_name}"' # noqa: S608 # nosec |
230 |
| - |
231 | 266 | if self.filters:
|
232 | 267 | sql += f" WHERE {where_clause}"
|
233 | 268 |
|
@@ -276,9 +311,17 @@ def fetch_all(self) -> list[BaseDBModel]:
|
276 | 311 | if not results:
|
277 | 312 | return []
|
278 | 313 |
|
| 314 | + if self._fields: |
| 315 | + return [ |
| 316 | + self.model_class.model_validate_partial( |
| 317 | + {field: row[idx] for idx, field in enumerate(self._fields)} |
| 318 | + ) |
| 319 | + for row in results |
| 320 | + ] |
| 321 | + |
279 | 322 | return [
|
280 |
| - self.model_class( |
281 |
| - **{ |
| 323 | + self.model_class.model_validate( |
| 324 | + { |
282 | 325 | field: row[idx]
|
283 | 326 | for idx, field in enumerate(self.model_class.model_fields)
|
284 | 327 | }
|
|
0 commit comments