diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 66b998a..ae88ddd 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.5.0
+current_version = 0.5.1
 commit = True
 tag = True
diff --git a/.env.docker b/.env.docker
new file mode 100644
index 0000000..8745f08
--- /dev/null
+++ b/.env.docker
@@ -0,0 +1 @@
+TEST_CONNECTION_STRING=postgresql://postgres:postgres@pg:5432/test_db
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index cfcb9ca..a9b6d66 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -2,7 +2,7 @@ name: Build
 
 on:
   push:
-    branches: [ main, dev/* ]
+    branches: [ main, '*' ]
   pull_request:
     branches: [ main ]
   workflow_dispatch:
@@ -16,10 +16,11 @@ jobs:
       - uses: actions/checkout@v3
       - name: Build Docker image
        run: |
-          docker compose build tests-with-coverage --quiet
+          docker compose build tests-with-coverage --quiet
+          docker compose pull
      - name: Run Tests via Docker
        run: |
-          docker compose up --exit-code-from tests-with-coverage tests-with-coverage
+          docker compose --env-file .env.docker run tests-with-coverage
      - name: Show Test Logs if tests failed
        if: ${{ failure() }}
        run: docker compose logs
diff --git a/Dockerfile b/Dockerfile
index 52aac40..cb294f1 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -26,4 +26,4 @@ RUN poetry install -E test
 
 FROM content as testing_and_coverage
 
-CMD poetry run pytest --cov=sqlalchemy_easy_softdelete --cov-branch --cov-report=term-missing --cov-report=xml tests
+CMD sleep 2 && poetry run pytest --cov=sqlalchemy_easy_softdelete --cov-branch --cov-report=term-missing --cov-report=xml tests
diff --git a/docker-compose.yml b/docker-compose.yml
index 9cef6e0..6e3267c 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -3,6 +3,10 @@ services:
   # Test Runner
   ##############################
   tests:
+    depends_on:
+      - pg
+    env_file:
+      - .env.docker
     environment:
       - PYTHONUNBUFFERED=1
     build:
@@ -13,6 +17,28 @@
   ##############################
   tests-with-coverage:
     extends: "tests"
+    # Set up volume so that coverage information can be relayed back to the outside
     volumes:
       - "./:/library"
+
+  ##############################
+  # PostgreSQL Instance
+  ##############################
+  pg:
+    image: postgres:14
+    volumes:
+      - pg_db_data:/var/lib/postgresql/data
+    ports:
+      - "9991:5432"
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_PASSWORD: postgres
+      POSTGRES_DB: test_db
+    logging:
+      options:
+        max-size: "1m"
+
+
+volumes:
+  pg_db_data:
diff --git a/makefile b/makefile
index 4459a09..18f39d8 100644
--- a/makefile
+++ b/makefile
@@ -1,11 +1,7 @@
 sources = sqlalchemy_easy_softdelete
 
 .PHONY: test format lint unittest coverage pre-commit clean
-test: format lint unittest
-
-format:
-	isort $(sources) tests
-	black $(sources) tests
+test: lint unittest
 
 lint:
 	flake8 $(sources) tests
diff --git a/poetry.lock b/poetry.lock
index 2b1b8f3..5c98dd7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -586,6 +586,14 @@ python-versions = ">=3.6.2"
 [package.dependencies]
 wcwidth = "*"
 
+[[package]]
+name = "psycopg2"
+version = "2.9.4"
+description = "psycopg2 - Python-PostgreSQL Database Adapter"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+
 [[package]]
 name = "ptyprocess"
 version = "0.7.0"
@@ -1059,7 +1067,7 @@ test = ["pytest", "black", "isort", "flake8", "flake8-docstrings", "pytest-cov"]
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "efb614c3ce5aceff4dd63bbf42b26ae11a482fac5fddb09f721c4f9be9a8ad4c"
+content-hash = "49fbf35ce105392b0133d8c930b4f6304024c0b9bfad7c4d2c4aea6574de5bff"
 
 [metadata.files]
 appnope = [
@@ -1443,6 +1451,19 @@ prompt-toolkit = [
     {file = "prompt_toolkit-3.0.31-py3-none-any.whl", hash = "sha256:9696f386133df0fc8ca5af4895afe5d78f5fcfe5258111c2a79a1c3e41ffa96d"},
     {file = "prompt_toolkit-3.0.31.tar.gz", hash = "sha256:9ada952c9d1787f52ff6d5f3484d0b4df8952787c087edf6a1f7c2cb1ea88148"},
 ]
+psycopg2 = [
+    {file = "psycopg2-2.9.4-cp310-cp310-win32.whl", hash = "sha256:8de6a9fc5f42fa52f559e65120dcd7502394692490c98fed1221acf0819d7797"},
+    {file = "psycopg2-2.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:1da77c061bdaab450581458932ae5e469cc6e36e0d62f988376e9f513f11cb5c"},
+    {file = "psycopg2-2.9.4-cp36-cp36m-win32.whl", hash = "sha256:a11946bad3557ca254f17357d5a4ed63bdca45163e7a7d2bfb8e695df069cc3a"},
+    {file = "psycopg2-2.9.4-cp36-cp36m-win_amd64.whl", hash = "sha256:46361c054df612c3cc813fdb343733d56543fb93565cff0f8ace422e4da06acb"},
+    {file = "psycopg2-2.9.4-cp37-cp37m-win32.whl", hash = "sha256:aafa96f2da0071d6dd0cbb7633406d99f414b40ab0f918c9d9af7df928a1accb"},
+    {file = "psycopg2-2.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:aa184d551a767ad25df3b8d22a0a62ef2962e0e374c04f6cbd1204947f540d61"},
+    {file = "psycopg2-2.9.4-cp38-cp38-win32.whl", hash = "sha256:839f9ea8f6098e39966d97fcb8d08548fbc57c523a1e27a1f0609addf40f777c"},
+    {file = "psycopg2-2.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:c7fa041b4acb913f6968fce10169105af5200f296028251d817ab37847c30184"},
+    {file = "psycopg2-2.9.4-cp39-cp39-win32.whl", hash = "sha256:07b90a24d5056687781ddaef0ea172fd951f2f7293f6ffdd03d4f5077801f426"},
+    {file = "psycopg2-2.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:849bd868ae3369932127f0771c08d1109b254f08d48dc42493c3d1b87cb2d308"},
+    {file = "psycopg2-2.9.4.tar.gz", hash = "sha256:d529926254e093a1b669f692a3aa50069bc71faf5b0ecd91686a78f62767d52f"},
+]
 ptyprocess = [
     {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
     {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"},
 ]
diff --git a/pyproject.toml b/pyproject.toml
index ac9d5d7..c04b257 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [tool]
 [tool.poetry]
 name = "sqlalchemy-easy-softdelete"
-version = "0.5.0"
+version = "0.5.1"
 homepage = "https://github.com/flipbit03/sqlalchemy-easy-softdelete"
 description = "Easily add soft-deletion to your SQLAlchemy Models."
 authors = ["Cadu "]
@@ -40,6 +40,7 @@ bump2version = {version = "^1.0.1", optional = true}
 [tool.poetry.dev-dependencies]
 ipython = "^8.4.0"
 snapshottest = "^0.6.0"
+psycopg2 = "^2.9.4"
 
 [tool.poetry.extras]
 test = [
diff --git a/setup.cfg b/setup.cfg
index 8aa7877..384b46d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,7 +17,7 @@ exclude =
     .git,
    .pytest_cache,
    .vscode,
    .github,
-    ./tests/snapshots/*
+    ./tests/*
 # By default test codes will be linted.
 # tests
diff --git a/sqlalchemy_easy_softdelete/__init__.py b/sqlalchemy_easy_softdelete/__init__.py
index 28b284c..9fa780f 100644
--- a/sqlalchemy_easy_softdelete/__init__.py
+++ b/sqlalchemy_easy_softdelete/__init__.py
@@ -2,4 +2,4 @@
 __author__ = """Cadu"""
 __email__ = 'cadu.coelho@gmail.com'
 
-__version__ = '0.5.0'
+__version__ = '0.5.1'
diff --git a/sqlalchemy_easy_softdelete/handler/__init__.py b/sqlalchemy_easy_softdelete/handler/__init__.py
index e69de29..18302ed 100644
--- a/sqlalchemy_easy_softdelete/handler/__init__.py
+++ b/sqlalchemy_easy_softdelete/handler/__init__.py
@@ -0,0 +1 @@
+"""Group of functions related to the query rewriting process."""
diff --git a/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py b/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py
index bb58715..6c4b560 100644
--- a/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py
+++ b/sqlalchemy_easy_softdelete/handler/rewriter/__init__.py
@@ -1,3 +1,5 @@
+"""Main query rewriter logic."""
+
 from typing import TypeVar, Union
 
 from sqlalchemy import Table
@@ -10,21 +12,42 @@
 
 
 class SoftDeleteQueryRewriter:
+    """Rewrites SQL statements based on configuration."""
+
     def __init__(self, deleted_field_name: str, disable_soft_delete_option_name: str):
+        """
+        Instantiate a new query rewriter.
+
+        Params:
+
+        deleted_field_name:
+            The name of the field that should be present in a table for soft-deletion
+            rewriting to occur
+
+        disable_soft_delete_option_name:
+            Execution option name (to use with .execution_options(xxxx=True) to disable
+            soft deletion rewriting in a query
+
+        """
         self.deleted_field_name = deleted_field_name
         self.disable_soft_delete_option_name = disable_soft_delete_option_name
 
     def rewrite_statement(self, stmt: Statement) -> Statement:
+        """Rewrite a single SQL-like Statement."""
         if isinstance(stmt, Select):
             return self.rewrite_select(stmt)
 
         if isinstance(stmt, FromStatement):
+            # Explicitly protect against INSERT with RETURNING
+            if not isinstance(stmt.element, Select):
+                return stmt
             stmt.element = self.rewrite_select(stmt.element)
             return stmt
 
         raise NotImplementedError(f"Unsupported statement type \"{(type(stmt))}\"!")
 
     def rewrite_select(self, stmt: Select) -> Select:
+        """Rewrite a Select Statement."""
         # if the user tagged this query with an execution_option to disable soft-delete filtering
         # simply return back the same stmt
         if stmt.get_execution_options().get(self.disable_soft_delete_option_name):
@@ -36,6 +59,7 @@ def rewrite_select(self, stmt: Select) -> Select:
         return stmt
 
     def rewrite_compound_select(self, stmt: CompoundSelect) -> CompoundSelect:
+        """Rewrite a Compound Select Statement."""
         # This needs to be done by array slice referencing instead of
         # a direct reassignment because the reassignment would not substitute the
         # value which is inside the CompoundSelect "by reference"
@@ -44,6 +68,7 @@
         return stmt
 
     def rewrite_element(self, subquery: Subquery) -> Subquery:
+        """Rewrite an object with a `.element` attribute and patch the query inside it."""
         if isinstance(subquery.element, CompoundSelect):
             subquery.element = self.rewrite_compound_select(subquery.element)
             return subquery
@@ -55,6 +80,7 @@
         raise NotImplementedError(f"Unsupported object \"{(type(subquery.element))}\" in subquery.element")
 
     def analyze_from(self, stmt: Select, from_obj):
+        """Analyze the FROMS of a Select to determine possible soft-delete rewritable tables."""
         if isinstance(from_obj, Table):
             return self.rewrite_from_table(stmt, from_obj)
 
@@ -85,6 +111,7 @@ def analyze_from(self, stmt: Select, from_obj):
         raise NotImplementedError(f"Unsupported object \"{(type(from_obj))}\" in statement.froms")
 
     def rewrite_from_table(self, stmt: Select, table: Table) -> Select:
+        """(possibly) Rewrite a Select based on whether the Table contains the soft-delete field or not."""
         column_obj = table.columns.get(self.deleted_field_name)
 
         # Caveat: The automatic "bool(column_obj)" conversion actually returns
diff --git a/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py b/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py
index c9fe127..481747a 100644
--- a/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py
+++ b/sqlalchemy_easy_softdelete/handler/sqlalchemy_easy_softdelete.py
@@ -1,3 +1,5 @@
+"""This module is responsible for activating the query rewriter."""
+
 from functools import cache
 
 from sqlalchemy.event import listens_for
@@ -8,6 +10,7 @@
 
 @cache
 def activate_soft_delete_hook(deleted_field_name: str, disable_soft_delete_option_name: str):
+    """Activate an event hook to rewrite the queries."""
     # Enable Soft Delete on all Relationship Loads which implement SoftDeleteMixin
     @listens_for(Session, "do_orm_execute")
     def soft_delete_execute(state: ORMExecuteState):
diff --git a/sqlalchemy_easy_softdelete/mixin.py b/sqlalchemy_easy_softdelete/mixin.py
index e6da12d..53771bd 100644
--- a/sqlalchemy_easy_softdelete/mixin.py
+++ b/sqlalchemy_easy_softdelete/mixin.py
@@ -1,3 +1,5 @@
+"""Functions related to dynamic generation of the soft-delete mixin."""
+
 from datetime import datetime
 from typing import Any, Callable, Optional, Type
 
@@ -18,6 +20,7 @@ def generate_soft_delete_mixin_class(
     generate_undelete_method: bool = True,
     undelete_method_name: str = "undelete",
 ) -> Type:
+    """Generate the actual soft-delete Mixin class."""
     class_attributes = {deleted_field_name: Column(deleted_field_name, deleted_field_type)}
 
     if generate_delete_method:
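Note: the docstrings added above describe the moving parts — the generated mixin, the query rewriter, and the `do_orm_execute` hook. The following is a hedged usage sketch, not part of the diff: the `Order` model is hypothetical, `SoftDeleteMixin` mirrors tests/model.py, and the `include_deleted` option name is the one exercised by tests/test_queries.py.

# Hedged sketch of the intended wiring, under assumed defaults.
from datetime import datetime

from sqlalchemy import Column, Integer
from sqlalchemy.orm import declarative_base

from sqlalchemy_easy_softdelete.mixin import generate_soft_delete_mixin_class

Base = declarative_base()


class SoftDeleteMixin(generate_soft_delete_mixin_class()):
    # Column contributed by the generated mixin; annotated for type checkers,
    # the same way tests/model.py does it.
    deleted_at: datetime


class Order(Base, SoftDeleteMixin):  # hypothetical model for illustration
    __tablename__ = "orders"
    id = Column(Integer, primary_key=True)


# Plain ORM queries against Order are expected to be rewritten to filter
# "deleted_at IS NULL"; the execution option seen in the tests bypasses that:
#     session.query(Order).execution_options(include_deleted=True).all()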
diff --git a/tests/conftest.py b/tests/conftest.py
index e7acb30..d07ee21 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,83 +1,52 @@
-import datetime
-import random
+import os
 
 import pytest
 from sqlalchemy import create_engine
+from sqlalchemy.engine import Connection, Engine
 from sqlalchemy.orm import Session, sessionmaker
 
 from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter
-from tests.model import SDChild, SDDerivedRequest, SDParent, TestModelBase
+from tests.model import TestModelBase
+from tests.seed_data import generate_parent_child_object_hierarchy, generate_table_with_inheritance_obj
 
-test_db_url = 'sqlite://'  # use in-memory database for tests
+test_db_url = os.environ.get("TEST_CONNECTION_STRING", "sqlite://")
 
 
-@pytest.fixture(scope="function")
-def session_factory():
-    engine = create_engine(test_db_url)
-    TestModelBase.metadata.create_all(engine)
-
-    yield sessionmaker(bind=engine)
-
-    # SQLite in-memory db is deleted when its connection is closed.
-    # https://www.sqlite.org/inmemorydb.html
-    engine.dispose()
-
-
-@pytest.fixture(scope="function")
-def session(session_factory) -> Session:
-    return session_factory()
-
-
-def generate_parent_child_object_hierarchy(
-    s: Session, parent_id: int, min_children: int = 1, max_children: int = 3, parent_deleted: bool = False
-):
-    # Fix a seed in the RNG for deterministic outputs
-    random.seed(parent_id)
-
-    # Generate the Parent
-    deleted_at = datetime.datetime.utcnow() if parent_deleted else None
-    new_parent = SDParent(id=parent_id, deleted_at=deleted_at)
-    s.add(new_parent)
-    s.flush()
-
-    active_children = random.randint(min_children, max_children)
+@pytest.fixture
+def db_engine() -> Engine:
+    return create_engine(test_db_url)
 
-    # Add some active children
-    for active_id in range(active_children):
-        new_child = SDChild(id=parent_id * 1000 + active_id, parent=new_parent)
-        s.add(new_child)
-        s.flush()
 
-    # Add some soft-deleted children
-    for inactive_id in range(random.randint(min_children, max_children)):
-        new_soft_deleted_child = SDChild(
-            id=parent_id * 1000 + active_children + inactive_id,
-            parent=new_parent,
-            deleted_at=datetime.datetime.utcnow(),
-        )
-        s.add(new_soft_deleted_child)
-        s.flush()
+@pytest.fixture
+def db_connection(db_engine) -> Connection:
+    connection = db_engine.connect()
 
-    s.commit()
+    # start a transaction
+    transaction = connection.begin()
+    try:
+        yield connection
+    finally:
+        transaction.rollback()
+        connection.close()
 
 
-def generate_table_with_inheritance_obj(s: Session, obj_id: int, deleted: bool = False):
-    deleted_at = datetime.datetime.utcnow() if deleted else None
-    new_parent = SDDerivedRequest(id=obj_id, deleted_at=deleted_at)
-    s.add(new_parent)
-    s.commit()
+@pytest.fixture
+def db_session(db_connection) -> Session:
+    TestModelBase.metadata.create_all(db_connection)
+    return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)()
 
 
-@pytest.fixture(scope="function")
-def seeded_session(session) -> Session:
-    generate_parent_child_object_hierarchy(session, 0)
-    generate_parent_child_object_hierarchy(session, 1)
-    generate_parent_child_object_hierarchy(session, 2, parent_deleted=True)
-    generate_table_with_inheritance_obj(session, 0, deleted=False)
-    generate_table_with_inheritance_obj(session, 1, deleted=False)
-    generate_table_with_inheritance_obj(session, 2, deleted=True)
-    return session
+@pytest.fixture
+def seeded_session(db_session) -> Session:
+    generate_parent_child_object_hierarchy(db_session, 1000)
+    generate_parent_child_object_hierarchy(db_session, 1001)
+    generate_parent_child_object_hierarchy(db_session, 1002, parent_deleted=True)
+
+    generate_table_with_inheritance_obj(db_session, 1000, deleted=False)
+    generate_table_with_inheritance_obj(db_session, 1001, deleted=False)
+    generate_table_with_inheritance_obj(db_session, 1002, deleted=True)
+    return db_session
 
 
 @pytest.fixture
diff --git a/tests/model.py b/tests/model.py
index d069725..1f042b6 100644
--- a/tests/model.py
+++ b/tests/model.py
@@ -13,7 +13,7 @@ class TestModelBase:
     def __tablename__(cls) -> str:
         return cls.__name__.lower()
 
-    id = Column(Integer, primary_key=True)
+    id = Column(Integer, primary_key=True, autoincrement=True)
 
     def __repr__(self):
         return f"<{self.__class__.__name__} id={self.id}>"
@@ -24,6 +24,14 @@ class SoftDeleteMixin(generate_soft_delete_mixin_class()):
     deleted_at: datetime
 
 
+class SDSimpleTable(TestModelBase, SoftDeleteMixin):
+
+    int_field = Column(Integer)
+
+    def __repr__(self):
+        return f"<{self.__class__.__name__} id={self.id} deleted={bool(self.deleted_at)}>"
+
+
 class SDParent(TestModelBase, SoftDeleteMixin):
     children: 'List[SDChild]' = relationship('SDChild')
 
@@ -47,6 +55,8 @@ class SDBaseRequest(
 ):
     request_type = Column(String(50))
 
+    base_field = Column(Integer)
+
     __mapper_args__ = {
         "polymorphic_identity": "sdbaserequest",
         "polymorphic_on": request_type,
@@ -56,7 +66,7 @@ class SDDerivedRequest(SDBaseRequest):
     id: Integer = Column(Integer, ForeignKey("sdbaserequest.id"), primary_key=True)
 
-    int_field = Column(Integer)
+    derived_field = Column(Integer)
 
     __mapper_args__ = {
         "polymorphic_identity": "sdderivedrequest",
diff --git a/tests/seed_data.py b/tests/seed_data.py
new file mode 100644
index 0000000..2ef14fa
--- /dev/null
+++ b/tests/seed_data.py
@@ -0,0 +1,46 @@
+import random
+from datetime import datetime
+
+from sqlalchemy.orm import Session
+
+from tests.model import SDChild, SDDerivedRequest, SDParent
+
+
+def generate_parent_child_object_hierarchy(
+    s: Session, parent_id: int, min_children: int = 1, max_children: int = 5, parent_deleted: bool = False
+):
+    # Fix a seed in the RNG for deterministic outputs
+    random.seed(parent_id)
+
+    # Generate the Parent
+    deleted_at = datetime.utcnow() if parent_deleted else None
+    new_parent = SDParent(id=parent_id, deleted_at=deleted_at)
+    s.add(new_parent)
+    s.flush()
+
+    active_children = random.randint(min_children, max_children)
+
+    # Add some active children
+    for active_id in range(active_children):
+        new_child = SDChild(id=parent_id * 1000 + active_id, parent=new_parent)
+        s.add(new_child)
+        s.flush()
+
+    # Add some soft-deleted children
+    for inactive_id in range(random.randint(min_children, max_children)):
+        new_soft_deleted_child = SDChild(
+            id=parent_id * 1000 + active_children + inactive_id,
+            parent=new_parent,
+            deleted_at=datetime.utcnow(),
+        )
+        s.add(new_soft_deleted_child)
+        s.flush()
+
+    s.commit()
+
+
+def generate_table_with_inheritance_obj(s: Session, obj_id: int, deleted: bool = False):
+    deleted_at = datetime.utcnow() if deleted else None
+    new_parent = SDDerivedRequest(id=obj_id, deleted_at=deleted_at)
+    s.add(new_parent)
+    s.commit()
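Note: the new conftest.py builds a fixture chain (engine → connection with an outer transaction → session) so that the deterministic seed data from tests/seed_data.py is discarded after every test. The following is a condensed sketch of that isolation recipe outside pytest, not part of the diff; the `sqlite://` URL is a stand-in for `TEST_CONNECTION_STRING`.

# Hedged sketch of the transaction-per-test pattern the db_connection/db_session
# fixtures follow: everything a test writes is rolled back with the outer transaction.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")  # stand-in for the real test database URL

connection = engine.connect()
transaction = connection.begin()  # outer transaction owned by the fixture
try:
    # Bind the Session to the already-open connection, as db_session does.
    session = sessionmaker(autocommit=False, autoflush=False, bind=connection)()
    # ... create tables on the connection, seed data, run the test body ...
    session.close()
finally:
    transaction.rollback()  # discard the test's writes
    connection.close()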
diff --git a/tests/snapshots/snap_test_queries.py b/tests/snapshots/snap_test_queries.py
index 4e63fe6..e3834b3 100644
--- a/tests/snapshots/snap_test_queries.py
+++ b/tests/snapshots/snap_test_queries.py
@@ -4,7 +4,6 @@
 
 from snapshottest import GenericRepr, Snapshot
 
-
 snapshots = Snapshot()
 
 snapshots['test_ensure_aggregate_from_multiple_table_deletion_works_active_object_count 1'] = '''SELECT count(*) AS count_1
@@ -13,19 +12,19 @@
 
 snapshots['test_ensure_aggregate_from_multiple_table_deletion_works_active_object_count 2'] = 1
 
-snapshots['test_ensure_table_with_inheritance_works 1'] = '''SELECT sdderivedrequest.id, sdbaserequest.id AS id_1, sdbaserequest.deleted_at, sdbaserequest.request_type, sdderivedrequest.int_field
+snapshots['test_ensure_table_with_inheritance_works 1'] = '''SELECT sdderivedrequest.id, sdbaserequest.id AS id_1, sdbaserequest.deleted_at, sdbaserequest.request_type, sdbaserequest.base_field, sdderivedrequest.derived_field
 FROM sdbaserequest JOIN sdderivedrequest ON sdbaserequest.id = sdderivedrequest.id
 WHERE sdbaserequest.deleted_at IS NULL'''
 
 snapshots['test_ensure_table_with_inheritance_works 2'] = [
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr('')
 ]
 
 snapshots['test_ensure_table_with_inheritance_works 3'] = [
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr('')
 ]
 
 snapshots['test_query_single_table 1'] = '''SELECT sdchild.id, sdchild.deleted_at, sdchild.parent_id
@@ -33,10 +32,16 @@
 WHERE sdchild.deleted_at IS NULL'''
 
 snapshots['test_query_single_table 2'] = [
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr('')
 ]
 
 snapshots['test_query_union_sdchild 1'] = '''SELECT anon_1.sdchild_id, anon_1.sdchild_deleted_at, anon_1.sdchild_parent_id
@@ -47,10 +52,16 @@
 WHERE sdchild.deleted_at IS NULL) AS anon_1'''
 
 snapshots['test_query_union_sdchild 2'] = [
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr('')
 ]
 
 snapshots['test_query_with_join 1'] = '''SELECT sdchild.id, sdchild.deleted_at, sdchild.parent_id
@@ -58,9 +69,11 @@
 WHERE sdchild.deleted_at IS NULL AND sdparent.deleted_at IS NULL'''
 
 snapshots['test_query_with_join 2'] = [
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr('')
 ]
 
 snapshots['test_query_with_table_clause_as_table 1'] = '''SELECT id
@@ -76,14 +89,22 @@
 FROM sdchild) AS anon_1'''
 
 snapshots['test_query_with_union_but_union_softdelete_disabled 2'] = [
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr('')
 ]
diff --git a/tests/snapshots/snap_test_seed_data.py b/tests/snapshots/snap_test_seed_data.py
index b1e102f..51efe04 100644
--- a/tests/snapshots/snap_test_seed_data.py
+++ b/tests/snapshots/snap_test_seed_data.py
@@ -4,24 +4,31 @@
 
 from snapshottest import GenericRepr, Snapshot
 
-
 snapshots = Snapshot()
 
 snapshots['test_ensure_stable_seed_data 1'] = [
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr('')
 ]
 
 snapshots['test_ensure_stable_seed_data 2'] = [
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr(''),
-    GenericRepr('')
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr(''),
+    GenericRepr('')
 ]
diff --git a/tests/test_queries.py b/tests/test_queries.py
index f2943d0..8377581 100644
--- a/tests/test_queries.py
+++ b/tests/test_queries.py
@@ -1,11 +1,12 @@
 """Tests for `sqlalchemy_easy_softdelete` package."""
 from typing import List
 
-from sqlalchemy import func, select, table, text
+import pytest
+from sqlalchemy import func, insert, select, table, text
 from sqlalchemy.orm import Query
 from sqlalchemy.sql import Select
 
-from tests.model import SDBaseRequest, SDChild, SDDerivedRequest, SDParent
+from tests.model import SDBaseRequest, SDChild, SDDerivedRequest, SDParent, SDSimpleTable
 
 
 def test_query_single_table(snapshot, seeded_session, rewriter):
@@ -13,7 +14,7 @@ def test_query_single_table(snapshot, seeded_session, rewriter):
     test_query: Query = seeded_session.query(SDChild)
 
     snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
-    snapshot.assert_match(test_query.all())
+    snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id))
 
 
 def test_query_with_join(snapshot, seeded_session, rewriter):
@@ -22,7 +23,7 @@
 
     snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
 
-    snapshot.assert_match(test_query.all())
+    snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id))
 
 
 def test_query_union_sdchild(snapshot, seeded_session, rewriter):
@@ -31,7 +32,7 @@
 
     snapshot.assert_match(str(rewriter.rewrite_select(test_query.statement)))
 
-    snapshot.assert_match(test_query.all())
+    snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id))
 
 
 def test_query_with_union_but_union_softdelete_disabled(snapshot, seeded_session, rewriter):
@@ -52,7 +53,7 @@
 
     assert sorted(test_query.all(), key=lambda x: x.id) == sorted(all_children, key=lambda x: x.id)
 
-    snapshot.assert_match(test_query.all())
+    snapshot.assert_match(sorted(test_query.all(), key=lambda i: i.id))
 
 
 def test_ensure_aggregate_from_multiple_table_deletion_works_active_object_count(snapshot, seeded_session, rewriter):
@@ -70,14 +71,14 @@ def test_ensure_table_with_inheritance_works(snapshot, seeded_session, rewriter)
     test_query_results = test_query.all()
 
     assert len(test_query_results) == 2
-    snapshot.assert_match(test_query_results)
+    snapshot.assert_match(sorted(test_query_results, key=lambda i: i.id))
 
     all_active_and_deleted_derived_requests = (
         seeded_session.query(SDDerivedRequest).execution_options(include_deleted=True).all()
     )
 
     assert len(all_active_and_deleted_derived_requests) == 3
-    snapshot.assert_match(all_active_and_deleted_derived_requests)
+    snapshot.assert_match(sorted(all_active_and_deleted_derived_requests, key=lambda i: i.id))
 
 
 def test_ensure_table_with_inheritance_works_query_base(snapshot, seeded_session, rewriter):
@@ -94,7 +95,7 @@
     try:
         # Accessing a field in a SDDerived Request will trigger an additional query with
         # a `FromStatement` as the statement, instead of a normal Select
-        request.int_field
+        request.derived_field
     except Exception as exc:
         assert False, f"'Exception was raised {exc}"
 
@@ -113,3 +114,21 @@
     # Table as a TableClause
     test_query_table_clause: Select = select(text('id')).select_from(table("sdderivedrequest"))
     snapshot.assert_match(str(rewriter.rewrite_select(test_query_table_clause)))
+
+
+def test_insert_with_returning(snapshot, seeded_session, rewriter, db_connection):
+    """Insert with RETURNING is considered a *Select* by SQLAlchemy, since it returns data :dizzy:
+    that means we need to actively protect against this case"""
+
+    # RETURNING is not supported in SQLite
+    if db_connection.dialect.name == 'sqlite':
+        pytest.skip('SQLite does not support "INSERT...RETURNING"')
+
+    insert_stmt = insert(SDSimpleTable).values(int_field=10).returning(SDSimpleTable)
+
+    # Generate an Insert + RETURNING
+    insert_returning = select(SDSimpleTable).from_statement(insert_stmt)
+
+    result = seeded_session.execute(insert_returning)
+
+    assert list(result)[0][0].int_field == 10
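Note: the new guard in rewrite_statement() and test_insert_with_returning above cover the same case. The following sketch, not part of the diff, illustrates why the guard is needed; the tests.model import assumes this repository's test models are importable.

# Hedged sketch: an INSERT ... RETURNING wrapped via from_statement() reaches the
# "do_orm_execute" hook as a FromStatement whose .element is an Insert, not a Select,
# so the rewriter must pass it through unchanged instead of rewriting it.
from sqlalchemy import insert, select
from sqlalchemy.sql import Select

from tests.model import SDSimpleTable

insert_stmt = insert(SDSimpleTable).values(int_field=10).returning(SDSimpleTable)
insert_returning = select(SDSimpleTable).from_statement(insert_stmt)

# The wrapped element is the Insert itself — exactly what the new
# isinstance(stmt.element, Select) check detects before bailing out.
assert not isinstance(insert_returning.element, Select)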