Skip to content

Commit

Permalink
Protect against Inserts that can be identified as Selects (INSERT...R…
Browse files Browse the repository at this point in the history
…ETURNING) in SQLAlchemy (#7)

* Avoid trying to rewrite an INSERT...RETURNING

* Bump version: 0.5.0 → 0.5.1
  • Loading branch information
flipbit03 authored Oct 20, 2022
1 parent 6ceb5ed commit e94e2f3
Show file tree
Hide file tree
Showing 20 changed files with 282 additions and 130 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.5.0
current_version = 0.5.1
commit = True
tag = True

Expand Down
1 change: 1 addition & 0 deletions .env.docker
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TEST_CONNECTION_STRING=postgresql://postgres:postgres@pg:5432/test_db
7 changes: 4 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Build

on:
push:
branches: [ main, dev/* ]
branches: [ main, '*' ]
pull_request:
branches: [ main ]
workflow_dispatch:
Expand All @@ -16,10 +16,11 @@ jobs:
- uses: actions/checkout@v3
- name: Build Docker image
run: |
docker compose build tests-with-coverage --no-cache
docker compose build tests-with-coverage --quiet
docker compose pull
- name: Run Tests via Docker
run: |
docker compose up --exit-code-from tests-with-coverage tests-with-coverage
docker compose --env-file .env.docker run tests-with-coverage
- name: Show Test Logs if tests failed
if: ${{ failure() }}
run: docker compose logs
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ RUN poetry install -E test

FROM content as testing_and_coverage

CMD poetry run pytest --cov=sqlalchemy_easy_softdelete --cov-branch --cov-report=term-missing --cov-report=xml tests
CMD sleep 2 && poetry run pytest --cov=sqlalchemy_easy_softdelete --cov-branch --cov-report=term-missing --cov-report=xml tests
26 changes: 26 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ services:
# Test Runner
##############################
tests:
depends_on:
- pg
env_file:
- .env.docker
environment:
- PYTHONUNBUFFERED=1
build:
Expand All @@ -13,6 +17,28 @@ services:
##############################
tests-with-coverage:
extends: "tests"

# Set up volume so that coverage information can be relayed back to the outside
volumes:
- "./:/library"

##############################
# PostgreSQL Instance
##############################
pg:
image: postgres:14
volumes:
- pg_db_data:/var/lib/postgresql/data
ports:
- "9991:5432"
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: test_db
logging:
options:
max-size: "1m"


volumes:
pg_db_data:
6 changes: 1 addition & 5 deletions makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
sources = sqlalchemy_easy_softdelete

.PHONY: test format lint unittest coverage pre-commit clean
test: format lint unittest

format:
isort $(sources) tests
black $(sources) tests
test: lint unittest

lint:
flake8 $(sources) tests
Expand Down
23 changes: 22 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool]
[tool.poetry]
name = "sqlalchemy-easy-softdelete"
version = "0.5.0"
version = "0.5.1"
homepage = "https://github.com/flipbit03/sqlalchemy-easy-softdelete"
description = "Easily add soft-deletion to your SQLAlchemy Models."
authors = ["Cadu <cadu.coelho@gmail.com>"]
Expand Down Expand Up @@ -40,6 +40,7 @@ bump2version = {version = "^1.0.1", optional = true}
[tool.poetry.dev-dependencies]
ipython = "^8.4.0"
snapshottest = "^0.6.0"
psycopg2 = "^2.9.4"

[tool.poetry.extras]
test = [
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ exclude = .git,
.pytest_cache,
.vscode,
.github,
./tests/snapshots/*
./tests/*
# By default test codes will be linted.
# tests

Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy_easy_softdelete/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__author__ = """Cadu"""
__email__ = 'cadu.coelho@gmail.com'
__version__ = '0.5.0'
__version__ = '0.5.1'
1 change: 1 addition & 0 deletions sqlalchemy_easy_softdelete/handler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Group of functions related to the query rewriting process."""
27 changes: 27 additions & 0 deletions sqlalchemy_easy_softdelete/handler/rewriter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Main query rewriter logic."""

from typing import TypeVar, Union

from sqlalchemy import Table
Expand All @@ -10,21 +12,42 @@


class SoftDeleteQueryRewriter:
"""Rewrites SQL statements based on configuration."""

def __init__(self, deleted_field_name: str, disable_soft_delete_option_name: str):
"""
Instantiate a new query rewriter.
Params:
deleted_field_name:
The name of the field that should be present in a table for soft-deletion
rewriting to occur
disable_soft_delete_option_name:
Execution option name (to use with .execution_options(xxxx=True) to disable
soft deletion rewriting in a query
"""
self.deleted_field_name = deleted_field_name
self.disable_soft_delete_option_name = disable_soft_delete_option_name

def rewrite_statement(self, stmt: Statement) -> Statement:
"""Rewrite a single SQL-like Statement."""
if isinstance(stmt, Select):
return self.rewrite_select(stmt)

if isinstance(stmt, FromStatement):
# Explicitly protect against INSERT with RETURNING
if not isinstance(stmt.element, Select):
return stmt
stmt.element = self.rewrite_select(stmt.element)
return stmt

raise NotImplementedError(f"Unsupported statement type \"{(type(stmt))}\"!")

def rewrite_select(self, stmt: Select) -> Select:
"""Rewrite a Select Statement."""
# if the user tagged this query with an execution_option to disable soft-delete filtering
# simply return back the same stmt
if stmt.get_execution_options().get(self.disable_soft_delete_option_name):
Expand All @@ -36,6 +59,7 @@ def rewrite_select(self, stmt: Select) -> Select:
return stmt

def rewrite_compound_select(self, stmt: CompoundSelect) -> CompoundSelect:
"""Rewrite a Compound Select Statement."""
# This needs to be done by array slice referencing instead of
# a direct reassignment because the reassignment would not substitute the
# value which is inside the CompoundSelect "by reference"
Expand All @@ -44,6 +68,7 @@ def rewrite_compound_select(self, stmt: CompoundSelect) -> CompoundSelect:
return stmt

def rewrite_element(self, subquery: Subquery) -> Subquery:
"""Rewrite an object with a `.element` attribute and patch the query inside it."""
if isinstance(subquery.element, CompoundSelect):
subquery.element = self.rewrite_compound_select(subquery.element)
return subquery
Expand All @@ -55,6 +80,7 @@ def rewrite_element(self, subquery: Subquery) -> Subquery:
raise NotImplementedError(f"Unsupported object \"{(type(subquery.element))}\" in subquery.element")

def analyze_from(self, stmt: Select, from_obj):
"""Analyze the FROMS of a Select to determine possible soft-delete rewritable tables."""
if isinstance(from_obj, Table):
return self.rewrite_from_table(stmt, from_obj)

Expand Down Expand Up @@ -85,6 +111,7 @@ def analyze_from(self, stmt: Select, from_obj):
raise NotImplementedError(f"Unsupported object \"{(type(from_obj))}\" in statement.froms")

def rewrite_from_table(self, stmt: Select, table: Table) -> Select:
"""(possibly) Rewrite a Select based on whether the Table contains the soft-delete field or not."""
column_obj = table.columns.get(self.deleted_field_name)

# Caveat: The automatic "bool(column_obj)" conversion actually returns
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""This module is responsible for activating the query rewriter."""

from functools import cache

from sqlalchemy.event import listens_for
Expand All @@ -8,6 +10,7 @@

@cache
def activate_soft_delete_hook(deleted_field_name: str, disable_soft_delete_option_name: str):
"""Activate an event hook to rewrite the queries."""
# Enable Soft Delete on all Relationship Loads which implement SoftDeleteMixin
@listens_for(Session, "do_orm_execute")
def soft_delete_execute(state: ORMExecuteState):
Expand Down
3 changes: 3 additions & 0 deletions sqlalchemy_easy_softdelete/mixin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Functions related to dynamic generation of the soft-delete mixin."""

from datetime import datetime
from typing import Any, Callable, Optional, Type

Expand All @@ -18,6 +20,7 @@ def generate_soft_delete_mixin_class(
generate_undelete_method: bool = True,
undelete_method_name: str = "undelete",
) -> Type:
"""Generate the actual soft-delete Mixin class."""
class_attributes = {deleted_field_name: Column(deleted_field_name, deleted_field_type)}

if generate_delete_method:
Expand Down
95 changes: 32 additions & 63 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,52 @@
import datetime
import random
import os

import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session, sessionmaker

from sqlalchemy_easy_softdelete.handler.rewriter import SoftDeleteQueryRewriter
from tests.model import SDChild, SDDerivedRequest, SDParent, TestModelBase
from tests.model import TestModelBase
from tests.seed_data import generate_parent_child_object_hierarchy, generate_table_with_inheritance_obj

test_db_url = 'sqlite://' # use in-memory database for tests
test_db_url = os.environ.get("TEST_CONNECTION_STRING", "sqlite://")


@pytest.fixture(scope="function")
def session_factory():
engine = create_engine(test_db_url)
TestModelBase.metadata.create_all(engine)

yield sessionmaker(bind=engine)

# SQLite in-memory db is deleted when its connection is closed.
# https://www.sqlite.org/inmemorydb.html
engine.dispose()


@pytest.fixture(scope="function")
def session(session_factory) -> Session:
return session_factory()


def generate_parent_child_object_hierarchy(
s: Session, parent_id: int, min_children: int = 1, max_children: int = 3, parent_deleted: bool = False
):
# Fix a seed in the RNG for deterministic outputs
random.seed(parent_id)

# Generate the Parent
deleted_at = datetime.datetime.utcnow() if parent_deleted else None
new_parent = SDParent(id=parent_id, deleted_at=deleted_at)
s.add(new_parent)
s.flush()

active_children = random.randint(min_children, max_children)
@pytest.fixture
def db_engine() -> Engine:
return create_engine(test_db_url)

# Add some active children
for active_id in range(active_children):
new_child = SDChild(id=parent_id * 1000 + active_id, parent=new_parent)
s.add(new_child)
s.flush()

# Add some soft-deleted children
for inactive_id in range(random.randint(min_children, max_children)):
new_soft_deleted_child = SDChild(
id=parent_id * 1000 + active_children + inactive_id,
parent=new_parent,
deleted_at=datetime.datetime.utcnow(),
)
s.add(new_soft_deleted_child)
s.flush()
@pytest.fixture
def db_connection(db_engine) -> Connection:
connection = db_engine.connect()

s.commit()
# start a transaction
transaction = connection.begin()

try:
yield connection
finally:
transaction.rollback()
connection.close()

def generate_table_with_inheritance_obj(s: Session, obj_id: int, deleted: bool = False):
deleted_at = datetime.datetime.utcnow() if deleted else None
new_parent = SDDerivedRequest(id=obj_id, deleted_at=deleted_at)
s.add(new_parent)
s.commit()

@pytest.fixture
def db_session(db_connection) -> Session:
TestModelBase.metadata.create_all(db_connection)
return sessionmaker(autocommit=False, autoflush=False, bind=db_connection)()

@pytest.fixture(scope="function")
def seeded_session(session) -> Session:
generate_parent_child_object_hierarchy(session, 0)
generate_parent_child_object_hierarchy(session, 1)
generate_parent_child_object_hierarchy(session, 2, parent_deleted=True)

generate_table_with_inheritance_obj(session, 0, deleted=False)
generate_table_with_inheritance_obj(session, 1, deleted=False)
generate_table_with_inheritance_obj(session, 2, deleted=True)
return session
@pytest.fixture
def seeded_session(db_session) -> Session:
generate_parent_child_object_hierarchy(db_session, 1000)
generate_parent_child_object_hierarchy(db_session, 1001)
generate_parent_child_object_hierarchy(db_session, 1002, parent_deleted=True)

generate_table_with_inheritance_obj(db_session, 1000, deleted=False)
generate_table_with_inheritance_obj(db_session, 1001, deleted=False)
generate_table_with_inheritance_obj(db_session, 1002, deleted=True)
return db_session


@pytest.fixture
Expand Down
Loading

0 comments on commit e94e2f3

Please sign in to comment.