Skip to content

Commit

Permalink
fix null filter when combined and add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Ramsay <seapagan@gmail.com>
  • Loading branch information
seapagan committed Sep 30, 2024
1 parent 450ca66 commit 09e194c
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 8 deletions.
18 changes: 10 additions & 8 deletions sqliter/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,11 @@ def filter(self, **conditions: str | float | None) -> QueryBuilder:
field_name, operator = self._parse_field_operator(field)
self._validate_field(field_name, valid_fields)

handler = self._get_operator_handler(operator)
handler(field_name, value, operator)
if operator in ["__isnull", "__notnull"]:
self._handle_null(field_name, value, operator)
else:
handler = self._get_operator_handler(operator)
handler(field_name, value, operator)

return self

Expand Down Expand Up @@ -280,23 +283,22 @@ def _handle_equality(
self.filters.append((field_name, value, operator))

def _handle_null(
self, field_name: str, _: FilterValue, operator: str
self, field_name: str, value: Union[str, float, None], operator: str
) -> None:
"""Handle IS NULL and IS NOT NULL filter conditions.
Args:
field_name: The name of the field to filter on. _: Placeholder for
unused value parameter.
operator: The operator string ('__isnull' or '__notnull').
value: The value to check for.
This method adds an IS NULL or IS NOT NULL condition to the filters
list.
"""
condition = (
f"{field_name} IS NOT NULL"
if operator == "__notnull"
else f"{field_name} IS NULL"
)
is_null = operator == "__isnull"
check_null = bool(value) if is_null else not bool(value)
condition = f"{field_name} IS {'NOT ' if not check_null else ''}NULL"
self.filters.append((condition, None, operator))

def _handle_in(
Expand Down
89 changes: 89 additions & 0 deletions tests/test_advanced_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,92 @@ def test_filter_with_bad_contains_condition(self, db_mock_adv) -> None:
str(exc_info.value)
== "name requires a string value for '__contains'"
)

def test_multiple_chained_filters(self, db_mock_adv) -> None:
"""Test multiple chained filters."""
# Insert an additional record
db_mock_adv.insert(PersonModel(name="Alex", age=28))

results = (
db_mock_adv.select(PersonModel)
.filter(age__gt=25)
.filter(name__startswith="A")
.fetch_all()
)

assert len(results) == 1
assert results[0].name == "Alex"
assert results[0].age == 28

def test_all_records_with_multiple_inclusive_filters(
self, db_mock_adv
) -> None:
"""Test using multiple filters in same filter() call."""
results = (
db_mock_adv.select(PersonModel)
.filter(age__gte=25, name__isnull=False)
.fetch_all()
)

assert len(results) == 3 # Original assertion
assert {result.name for result in results} == {
"Alice",
"Bob",
"Charlie",
}

def test_name_isnull_and_notnull_filters(self, db_mock_adv) -> None:
"""Test various filters with __isnull and __notnull."""
# Test __isnull=False
results = (
db_mock_adv.select(PersonModel)
.filter(name__isnull=False)
.fetch_all()
)
assert len(results) == 3
assert all(result.name is not None for result in results)

# Test __isnull=True (should return no results as all names are set)
results = (
db_mock_adv.select(PersonModel)
.filter(name__isnull=True)
.fetch_all()
)
assert len(results) == 0

# Test __notnull=True
results = (
db_mock_adv.select(PersonModel)
.filter(name__notnull=True)
.fetch_all()
)
assert len(results) == 3
assert all(result.name is not None for result in results)

# Test __notnull=False (should return no results as all names are set)
results = (
db_mock_adv.select(PersonModel)
.filter(name__notnull=False)
.fetch_all()
)
assert len(results) == 0

# Add a record with null name to test opposite cases
db_mock_adv.insert(PersonModel(name=None, age=40))

# Now test __isnull=True and __notnull=False with the new record
results = (
db_mock_adv.select(PersonModel)
.filter(name__isnull=True)
.fetch_all()
)
assert len(results) == 1
assert results[0].name is None

results = (
db_mock_adv.select(PersonModel)
.filter(name__notnull=False)
.fetch_all()
)
assert len(results) == 1
assert results[0].name is None

0 comments on commit 09e194c

Please sign in to comment.