diff --git a/tests/fields/test_array.py b/tests/fields/test_array.py index f99072787..a4fafa541 100644 --- a/tests/fields/test_array.py +++ b/tests/fields/test_array.py @@ -115,3 +115,23 @@ async def test_contained_by_strs(self): array_str__contained_by=["x", "y", "z"] ).values_list("array_str", flat=True) self.assertEqual(list(found), []) + + async def test_overlap_ints(self): + await testmodels.ArrayFields.create(array=[1, 2, 3]) + await testmodels.ArrayFields.create(array=[2, 3, 4]) + await testmodels.ArrayFields.create(array=[3, 4, 5]) + + found = await testmodels.ArrayFields.filter(array__overlap=[1, 2]).values_list( + "array", flat=True + ) + self.assertEqual(sorted(list(found)), [[1, 2, 3], [2, 3, 4]]) + + found = await testmodels.ArrayFields.filter(array__overlap=[4]).values_list( + "array", flat=True + ) + self.assertEqual(sorted(list(found)), [[2, 3, 4], [3, 4, 5]]) + + found = await testmodels.ArrayFields.filter(array__overlap=[1, 2, 3, 4, 5]).values_list( + "array", flat=True + ) + self.assertEqual(sorted(list(found)), [[1, 2, 3], [2, 3, 4], [3, 4, 5]]) diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py index 6127c2498..9720b8a4c 100644 --- a/tortoise/backends/base_postgres/executor.py +++ b/tortoise/backends/base_postgres/executor.py @@ -10,6 +10,7 @@ from tortoise.contrib.postgres.array_functions import ( postgres_array_contains, postgres_array_contained_by, + postgres_array_overlap, ) from tortoise.contrib.postgres.json_functions import ( postgres_json_contained_by, @@ -24,6 +25,7 @@ from tortoise.filters import ( array_contains, array_contained_by, + array_overlap, insensitive_posix_regex, json_contained_by, json_contains, @@ -44,6 +46,7 @@ class BasePostgresExecutor(BaseExecutor): search: postgres_search, array_contains: postgres_array_contains, array_contained_by: postgres_array_contained_by, + array_overlap: postgres_array_overlap, json_contains: postgres_json_contains, json_contained_by: postgres_json_contained_by, json_filter: postgres_json_filter, diff --git a/tortoise/contrib/postgres/array_functions.py b/tortoise/contrib/postgres/array_functions.py index 97fe4cba9..c6247978e 100644 --- a/tortoise/contrib/postgres/array_functions.py +++ b/tortoise/contrib/postgres/array_functions.py @@ -7,6 +7,7 @@ class PostgresArrayOperators(str, Enum): CONTAINS = "@>" CONTAINED_BY = "<@" + OVERLAP = "&&" def postgres_array_contains(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: @@ -18,3 +19,7 @@ def postgres_array_contains(field: Term, value: Union[Any, Sequence[Any]]) -> Cr def postgres_array_contained_by(field: Term, value: Sequence[Any]) -> Criterion: return BasicCriterion(PostgresArrayOperators.CONTAINED_BY, field, Array(*value)) + + +def postgres_array_overlap(field: Term, value: Sequence[Any]) -> Criterion: + return BasicCriterion(PostgresArrayOperators.OVERLAP, field, Array(*value)) diff --git a/tortoise/filters.py b/tortoise/filters.py index 0cace4bf5..e4f49550b 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -234,6 +234,10 @@ def array_contained_by(field: Term, value: Union[Any, Sequence[Any]]) -> Criteri raise NotImplementedError("must be overridden in each executor") +def array_overlap(field: Term, value: Union[Any, Sequence[Any]]) -> Criterion: + raise NotImplementedError("must be overridden in each executor") + + ############################################################################## # Filter resolvers ############################################################################## @@ -421,6 +425,11 @@ def get_array_filter(field_name: str, source_field: str) -> dict[str, FilterInfo "source_field": source_field, "operator": array_contained_by, }, + f"{field_name}__overlap": { + "field": field_name, + "source_field": source_field, + "operator": array_overlap, + }, }