diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 7b40224..b115500 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.6.2 +current_version = 0.6.4 commit = True tag = False diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a9b6d66..9bbd03b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,7 +2,7 @@ name: Build on: push: - branches: [ main, '*' ] + branches: [ main ] pull_request: branches: [ main ] workflow_dispatch: diff --git a/makefile b/Makefile similarity index 67% rename from makefile rename to Makefile index 18f39d8..9498208 100644 --- a/makefile +++ b/Makefile @@ -1,22 +1,24 @@ sources = sqlalchemy_easy_softdelete -.PHONY: test format lint unittest coverage pre-commit clean -test: lint unittest - +.PHONY: lint test coverage clean lint: - flake8 $(sources) tests + pre-commit run --all-files -unittest: +test: pytest coverage: pytest --cov=$(sources) --cov-branch --cov-report=term-missing --cov-report=xml tests -pre-commit: - pre-commit run --all-files - clean: rm -rf .pytest_cache rm -rf *.egg-info rm -rf .tox dist site rm -rf coverage.xml .coverage + +dev: + # Start Postgres Instance + docker compose up -d pg + +bump_patch: + bump2version patch --no-tag diff --git a/pyproject.toml b/pyproject.toml index dc39cb8..94fa12e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool] [tool.poetry] name = "sqlalchemy-easy-softdelete" -version = "0.6.2" +version = "0.6.4" homepage = "https://github.com/flipbit03/sqlalchemy-easy-softdelete" description = "Easily add soft-deletion to your SQLAlchemy Models." authors = ["Cadu "] diff --git a/setup.cfg b/setup.cfg index 384b46d..412b380 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] max-line-length = 120 max-complexity = 18 -ignore = E203, E266, W503, TYP001 +ignore = E203, E266, W503, TYP001, D202 docstring-convention = pep257 per-file-ignores = __init__.py:F401 exclude = .git, diff --git a/sqlalchemy_easy_softdelete/__init__.py b/sqlalchemy_easy_softdelete/__init__.py index 9bece67..aeeef0d 100644 --- a/sqlalchemy_easy_softdelete/__init__.py +++ b/sqlalchemy_easy_softdelete/__init__.py @@ -2,4 +2,4 @@ __author__ = """Cadu""" __email__ = 'cadu.coelho@gmail.com' -__version__ = '0.6.2' +__version__ = '0.6.4' diff --git a/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py b/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py index 640d841..bcd4931 100644 --- a/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py +++ b/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py @@ -5,10 +5,10 @@ from sqlalchemy import Table from sqlalchemy.orm import FromStatement from sqlalchemy.orm.util import _ORMJoin -from sqlalchemy.sql import Alias, CompoundSelect, Join, Select, Subquery, TableClause +from sqlalchemy.sql import Alias, CompoundSelect, Executable, Join, Select, Subquery, TableClause from sqlalchemy.sql.elements import TextClause -Statement = TypeVar('Statement', bound=Union[Select, FromStatement]) +Statement = TypeVar('Statement', bound=Union[Select, FromStatement, CompoundSelect, Executable]) class SoftDeleteQueryRewriter: @@ -37,6 +37,11 @@ def rewrite_statement(self, stmt: Statement) -> Statement: if isinstance(stmt, Select): return self.rewrite_select(stmt) + # Handle CompoundSelect + if isinstance(stmt, CompoundSelect): + return self.rewrite_compound_select(stmt) + + # Handle FromStatement which is also a Select/Executable if isinstance(stmt, FromStatement): # Explicitly protect against INSERT with RETURNING if not isinstance(stmt.element, Select): @@ -80,6 +85,8 @@ def rewrite_element(self, subquery: Subquery) -> Subquery: raise NotImplementedError(f"Unsupported object \"{(type(subquery.element))}\" in subquery.element") def rewrite_from_orm_join(self, stmt: Select, join_obj: Union[_ORMJoin, Join]) -> Select: + """Handle multiple, and potentially recursive joins.""" + # Recursive cases (multiple joins) if isinstance(join_obj.left, _ORMJoin) or isinstance(join_obj.left, Join): stmt = self.rewrite_from_orm_join(stmt, join_obj.left) diff --git a/tests/conftest.py b/tests/conftest.py index 14d8536..2827d09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,13 +12,18 @@ env_connection_string = os.environ.get("TEST_CONNECTION_STRING", None) -test_db_url = env_connection_string or "sqlite://" + +@pytest.fixture +def sqla2_warnings() -> Engine: + # Enable SQLAlchemy 2.0 Warnings mode to help with 2.0 support + os.environ["SQLALCHEMY_WARN_20"] = "1" @pytest.fixture -def db_engine() -> Engine: +def db_engine(sqla2_warnings) -> Engine: + test_db_url = env_connection_string or "sqlite://" print(f"connection_string={test_db_url}") - return create_engine(test_db_url) + return create_engine(test_db_url, future=True) @pytest.fixture diff --git a/tests/seed_data/parent_child_childchild.py b/tests/seed_data/parent_child_childchild.py index 2ae7a2e..f402710 100644 --- a/tests/seed_data/parent_child_childchild.py +++ b/tests/seed_data/parent_child_childchild.py @@ -32,7 +32,7 @@ def generate_parent_child_object_hierarchy( # Create Children (SDChild) for child_no, child_deleted in enumerate(children): new_child_id = parent_id * 100 + child_no - new_child = SDChild(id=new_child_id, parent=new_parent) + new_child = SDChild(id=new_child_id, parent_id=new_parent.id) new_child.deleted_at = pseudorandom_date() if child_deleted else None s.add(new_child) s.flush() @@ -45,7 +45,7 @@ def generate_parent_child_object_hierarchy( for child_children_no, child_children_deleted in enumerate(child_children): sdchild_child_id = new_child_id * 100 + child_children_no - new_child_child = SDChildChild(id=sdchild_child_id, child=new_child) + new_child_child = SDChildChild(id=sdchild_child_id, child_id=new_child.id) child_child_deleted = pseudorandom_date() if child_children_deleted else None new_child_child.deleted_at = child_child_deleted s.add(new_child_child) diff --git a/tests/snapshots/snap_test_queries.py b/tests/snapshots/snap_test_queries.py index 293c9d1..adef08f 100644 --- a/tests/snapshots/snap_test_queries.py +++ b/tests/snapshots/snap_test_queries.py @@ -65,6 +65,12 @@ GenericRepr('') ] +snapshots['test_query_union_sdchild_core 1'] = '''SELECT sdchild.id, sdchild.parent_id +FROM sdchild +WHERE sdchild.deleted_at IS NULL UNION SELECT sdchild.id, sdchild.parent_id +FROM sdchild +WHERE sdchild.deleted_at IS NULL''' + snapshots['test_query_with_join 1'] = '''SELECT sdchild.id, sdchild.deleted_at, sdchild.parent_id FROM sdchild JOIN sdparent ON sdparent.id = sdchild.parent_id WHERE sdchild.deleted_at IS NULL AND sdparent.deleted_at IS NULL''' diff --git a/tests/test_queries.py b/tests/test_queries.py index 6ecb368..991055b 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -13,7 +13,7 @@ def test_query_single_table(snapshot, seeded_session, rewriter): """Query with one table""" test_query: Query = seeded_session.query(SDChild) - snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement))) snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id)) @@ -21,7 +21,7 @@ def test_query_with_join(snapshot, seeded_session, rewriter): """Query with a simple join""" test_query: Query = seeded_session.query(SDChild).join(SDParent) # noqa -- wrong typing stub in SA - snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement))) snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id)) @@ -30,11 +30,22 @@ def test_query_union_sdchild(snapshot, seeded_session, rewriter): """Two queries joined via UNION""" test_query: Query = seeded_session.query(SDChild).union(seeded_session.query(SDChild)) - snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement))) snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id)) +def test_query_union_sdchild_core(snapshot, seeded_session, rewriter): + """Two queries joined via UNION, using SQLAlchemy Core""" + sdchild = SDChild.__table__ + + select_as_core = (select(sdchild.c.id, sdchild.c.parent_id).select_from(sdchild)).union( + select(sdchild.c.id, sdchild.c.parent_id).select_from(sdchild) + ) + + snapshot.assert_match(str(rewriter.rewrite_statement(select_as_core))) + + def test_query_with_union_but_union_softdelete_disabled(snapshot, seeded_session, rewriter): """Two queries joined via UNION but the second one has soft-delete disabled""" @@ -47,7 +58,7 @@ def test_query_with_union_but_union_softdelete_disabled(snapshot, seeded_session seeded_session.query(SDChild).execution_options(include_deleted=True) ) - snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement))) all_children: List[SDChild] = seeded_session.query(SDChild).execution_options(include_deleted=True).all() @@ -60,14 +71,14 @@ def test_ensure_aggregate_from_multiple_table_deletion_works_active_object_count """Aggregate function from a query that contains a join""" test_query: Query = seeded_session.query(SDChild).join(SDParent).with_entities(func.count()) # noqa - snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement))) snapshot.assert_match(test_query.count()) def test_ensure_table_with_inheritance_works(snapshot, seeded_session, rewriter): test_query: Query = seeded_session.query(SDDerivedRequest) - snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement))) test_query_results = test_query.all() assert len(test_query_results) == 2 @@ -105,7 +116,7 @@ def test_query_with_text_clause_as_table(snapshot, seeded_session, rewriter): # Table as a TextClause test_query_text_clause: Select = select(text('id')).select_from(text("sdderivedrequest")) - snapshot.assert_match(str(rewriter.rewrite_select(test_query_text_clause))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query_text_clause))) def test_query_with_table_clause_as_table(snapshot, seeded_session, rewriter): @@ -113,7 +124,7 @@ def test_query_with_table_clause_as_table(snapshot, seeded_session, rewriter): # Table as a TableClause test_query_table_clause: Select = select(text('id')).select_from(table("sdderivedrequest")) - snapshot.assert_match(str(rewriter.rewrite_select(test_query_table_clause))) + snapshot.assert_match(str(rewriter.rewrite_statement(test_query_table_clause))) def test_insert_with_returning(snapshot, seeded_session, rewriter, db_connection): @@ -144,4 +155,4 @@ def test_query_with_more_than_one_join(snapshot, seeded_session, rewriter): ) ) - snapshot.assert_match(str(rewriter.rewrite_select(query.statement))) + snapshot.assert_match(str(rewriter.rewrite_statement(query.statement)))