Skip to content

Commit

Permalink
Support .union() queries using SQLAlchemy Core (#13)
Browse files Browse the repository at this point in the history
* Make library support CompoundSelect

* Simplified some call sites for SQLAlchemy 2.0

* Updated Makefile
  • Loading branch information
flipbit03 authored Apr 10, 2023
1 parent 86a202b commit bc572e4
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.2
current_version = 0.6.4
commit = True
tag = False

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Build

on:
push:
branches: [ main, '*' ]
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
Expand Down
18 changes: 10 additions & 8 deletions makefile → Makefile
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
sources = sqlalchemy_easy_softdelete

.PHONY: test format lint unittest coverage pre-commit clean
test: lint unittest

.PHONY: lint test coverage clean
lint:
flake8 $(sources) tests
pre-commit run --all-files

unittest:
test:
pytest

coverage:
pytest --cov=$(sources) --cov-branch --cov-report=term-missing --cov-report=xml tests

pre-commit:
pre-commit run --all-files

clean:
rm -rf .pytest_cache
rm -rf *.egg-info
rm -rf .tox dist site
rm -rf coverage.xml .coverage

dev:
# Start Postgres Instance
docker compose up -d pg

bump_patch:
bump2version patch --no-tag
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool]
[tool.poetry]
name = "sqlalchemy-easy-softdelete"
version = "0.6.2"
version = "0.6.4"
homepage = "https://github.com/flipbit03/sqlalchemy-easy-softdelete"
description = "Easily add soft-deletion to your SQLAlchemy Models."
authors = ["Cadu <cadu.coelho@gmail.com>"]
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[flake8]
max-line-length = 120
max-complexity = 18
ignore = E203, E266, W503, TYP001
ignore = E203, E266, W503, TYP001, D202
docstring-convention = pep257
per-file-ignores = __init__.py:F401
exclude = .git,
Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy_easy_softdelete/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__author__ = """Cadu"""
__email__ = 'cadu.coelho@gmail.com'
__version__ = '0.6.2'
__version__ = '0.6.4'
11 changes: 9 additions & 2 deletions sqlalchemy_easy_softdelete/handler/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from sqlalchemy import Table
from sqlalchemy.orm import FromStatement
from sqlalchemy.orm.util import _ORMJoin
from sqlalchemy.sql import Alias, CompoundSelect, Join, Select, Subquery, TableClause
from sqlalchemy.sql import Alias, CompoundSelect, Executable, Join, Select, Subquery, TableClause
from sqlalchemy.sql.elements import TextClause

Statement = TypeVar('Statement', bound=Union[Select, FromStatement])
Statement = TypeVar('Statement', bound=Union[Select, FromStatement, CompoundSelect, Executable])


class SoftDeleteQueryRewriter:
Expand Down Expand Up @@ -37,6 +37,11 @@ def rewrite_statement(self, stmt: Statement) -> Statement:
if isinstance(stmt, Select):
return self.rewrite_select(stmt)

# Handle CompoundSelect
if isinstance(stmt, CompoundSelect):
return self.rewrite_compound_select(stmt)

# Handle FromStatement which is also a Select/Executable
if isinstance(stmt, FromStatement):
# Explicitly protect against INSERT with RETURNING
if not isinstance(stmt.element, Select):
Expand Down Expand Up @@ -80,6 +85,8 @@ def rewrite_element(self, subquery: Subquery) -> Subquery:
raise NotImplementedError(f"Unsupported object \"{(type(subquery.element))}\" in subquery.element")

def rewrite_from_orm_join(self, stmt: Select, join_obj: Union[_ORMJoin, Join]) -> Select:
"""Handle multiple, and potentially recursive joins."""

# Recursive cases (multiple joins)
if isinstance(join_obj.left, _ORMJoin) or isinstance(join_obj.left, Join):
stmt = self.rewrite_from_orm_join(stmt, join_obj.left)
Expand Down
11 changes: 8 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@

env_connection_string = os.environ.get("TEST_CONNECTION_STRING", None)

test_db_url = env_connection_string or "sqlite://"

@pytest.fixture
def sqla2_warnings() -> Engine:
# Enable SQLAlchemy 2.0 Warnings mode to help with 2.0 support
os.environ["SQLALCHEMY_WARN_20"] = "1"


@pytest.fixture
def db_engine() -> Engine:
def db_engine(sqla2_warnings) -> Engine:
test_db_url = env_connection_string or "sqlite://"
print(f"connection_string={test_db_url}")
return create_engine(test_db_url)
return create_engine(test_db_url, future=True)


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions tests/seed_data/parent_child_childchild.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def generate_parent_child_object_hierarchy(
# Create Children (SDChild)
for child_no, child_deleted in enumerate(children):
new_child_id = parent_id * 100 + child_no
new_child = SDChild(id=new_child_id, parent=new_parent)
new_child = SDChild(id=new_child_id, parent_id=new_parent.id)
new_child.deleted_at = pseudorandom_date() if child_deleted else None
s.add(new_child)
s.flush()
Expand All @@ -45,7 +45,7 @@ def generate_parent_child_object_hierarchy(

for child_children_no, child_children_deleted in enumerate(child_children):
sdchild_child_id = new_child_id * 100 + child_children_no
new_child_child = SDChildChild(id=sdchild_child_id, child=new_child)
new_child_child = SDChildChild(id=sdchild_child_id, child_id=new_child.id)
child_child_deleted = pseudorandom_date() if child_children_deleted else None
new_child_child.deleted_at = child_child_deleted
s.add(new_child_child)
Expand Down
6 changes: 6 additions & 0 deletions tests/snapshots/snap_test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
GenericRepr('<SDChild id=100204 deleted=False (parent_id=1002)>')
]

snapshots['test_query_union_sdchild_core 1'] = '''SELECT sdchild.id, sdchild.parent_id
FROM sdchild
WHERE sdchild.deleted_at IS NULL UNION SELECT sdchild.id, sdchild.parent_id
FROM sdchild
WHERE sdchild.deleted_at IS NULL'''

snapshots['test_query_with_join 1'] = '''SELECT sdchild.id, sdchild.deleted_at, sdchild.parent_id
FROM sdchild JOIN sdparent ON sdparent.id = sdchild.parent_id
WHERE sdchild.deleted_at IS NULL AND sdparent.deleted_at IS NULL'''
Expand Down
29 changes: 20 additions & 9 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def test_query_single_table(snapshot, seeded_session, rewriter):
"""Query with one table"""
test_query: Query = seeded_session.query(SDChild)

snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement)))
snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id))


def test_query_with_join(snapshot, seeded_session, rewriter):
"""Query with a simple join"""
test_query: Query = seeded_session.query(SDChild).join(SDParent) # noqa -- wrong typing stub in SA

snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement)))

snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id))

Expand All @@ -30,11 +30,22 @@ def test_query_union_sdchild(snapshot, seeded_session, rewriter):
"""Two queries joined via UNION"""
test_query: Query = seeded_session.query(SDChild).union(seeded_session.query(SDChild))

snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement)))

snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id))


def test_query_union_sdchild_core(snapshot, seeded_session, rewriter):
"""Two queries joined via UNION, using SQLAlchemy Core"""
sdchild = SDChild.__table__

select_as_core = (select(sdchild.c.id, sdchild.c.parent_id).select_from(sdchild)).union(
select(sdchild.c.id, sdchild.c.parent_id).select_from(sdchild)
)

snapshot.assert_match(str(rewriter.rewrite_statement(select_as_core)))


def test_query_with_union_but_union_softdelete_disabled(snapshot, seeded_session, rewriter):
"""Two queries joined via UNION but the second one has soft-delete disabled"""

Expand All @@ -47,7 +58,7 @@ def test_query_with_union_but_union_softdelete_disabled(snapshot, seeded_session
seeded_session.query(SDChild).execution_options(include_deleted=True)
)

snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement)))

all_children: List[SDChild] = seeded_session.query(SDChild).execution_options(include_deleted=True).all()

Expand All @@ -60,14 +71,14 @@ def test_ensure_aggregate_from_multiple_table_deletion_works_active_object_count
"""Aggregate function from a query that contains a join"""
test_query: Query = seeded_session.query(SDChild).join(SDParent).with_entities(func.count()) # noqa

snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement)))
snapshot.assert_match(test_query.count())


def test_ensure_table_with_inheritance_works(snapshot, seeded_session, rewriter):
test_query: Query = seeded_session.query(SDDerivedRequest)

snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query.statement)))

test_query_results = test_query.all()
assert len(test_query_results) == 2
Expand Down Expand Up @@ -105,15 +116,15 @@ def test_query_with_text_clause_as_table(snapshot, seeded_session, rewriter):

# Table as a TextClause
test_query_text_clause: Select = select(text('id')).select_from(text("sdderivedrequest"))
snapshot.assert_match(str(rewriter.rewrite_select(test_query_text_clause)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query_text_clause)))


def test_query_with_table_clause_as_table(snapshot, seeded_session, rewriter):
"""We cannot parse information from a literal text table name -- return unchanged"""

# Table as a TableClause
test_query_table_clause: Select = select(text('id')).select_from(table("sdderivedrequest"))
snapshot.assert_match(str(rewriter.rewrite_select(test_query_table_clause)))
snapshot.assert_match(str(rewriter.rewrite_statement(test_query_table_clause)))


def test_insert_with_returning(snapshot, seeded_session, rewriter, db_connection):
Expand Down Expand Up @@ -144,4 +155,4 @@ def test_query_with_more_than_one_join(snapshot, seeded_session, rewriter):
)
)

snapshot.assert_match(str(rewriter.rewrite_select(query.statement)))
snapshot.assert_match(str(rewriter.rewrite_statement(query.statement)))

0 comments on commit bc572e4

Please sign in to comment.