Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Operation Type support #215

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/schemas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ Deleting a schema

Any schema can be dropped, including ones not created by :class:`~psqlextra.schema.PostgresSchema`.

The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` erorr will be raised if you attempt to drop the ``public`` schema.
The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` error will be raised if you attempt to drop the ``public`` schema.

.. warning::

Expand Down
63 changes: 45 additions & 18 deletions psqlextra/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
import sys

from collections.abc import Iterable
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import django

from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.db.models import Expression, Model, Q
from django.db.models import Expression, Field, Model, Q
from django.db.models.fields.related import RelatedField
from django.db.models.sql import compiler as django_compiler
from django.db.utils import ProgrammingError

from .expressions import HStoreValue
from .types import ConflictAction
from .types import ConflictAction, UpsertOperation


def append_caller_to_sql(sql):
def append_caller_to_sql(sql) -> str:
"""Append the caller to SQL queries.

Adds the calling file and function as an SQL comment to each query.
Expand Down Expand Up @@ -162,26 +162,39 @@ def as_sql(self, *args, **kwargs):
class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined]
"""Compiler for SQL INSERT statements."""

RETURNING_OPERATION_TYPE_CLAUSE = (
f"CASE WHEN xmax::text::int > 0 "
f"THEN '{UpsertOperation.UPDATE.value}' "
f"ELSE '{UpsertOperation.INSERT.value}' END"
)
RETURNING_OPERATION_TYPE_FIELD = "_operation_type"

def __init__(self, *args, **kwargs):
"""Initializes a new instance of
:see:PostgresInsertOnConflictCompiler."""
super().__init__(*args, **kwargs)
self.qn = self.connection.ops.quote_name

def as_sql(self, return_id=False, *args, **kwargs):
def as_sql(
self,
return_id=False,
return_operation_type=False,
*args,
**kwargs,
):
"""Builds the SQL INSERT statement."""
queries = [
self._rewrite_insert(sql, params, return_id)
self._rewrite_insert(sql, params, return_id, return_operation_type)
for sql, params in super().as_sql(*args, **kwargs)
]

return queries

def execute_sql(self, return_id=False):
def execute_sql(self, return_id=False, return_operation_type=False):
# execute all the generate queries
with self.connection.cursor() as cursor:
rows = []
for sql, params in self.as_sql(return_id):
for sql, params in self.as_sql(return_id, return_operation_type):
cursor.execute(sql, params)
try:
rows.extend(cursor.fetchall())
Expand All @@ -199,7 +212,9 @@ def execute_sql(self, return_id=False):
for row in rows
]

def _rewrite_insert(self, sql, params, return_id=False):
def _rewrite_insert(
self, sql, params, return_id=False, return_operation_type=False
):
"""Rewrites a formed SQL INSERT query to include the ON CONFLICT
clause.

Expand All @@ -221,16 +236,27 @@ def _rewrite_insert(self, sql, params, return_id=False):
returning = (
self.qn(self.query.model._meta.pk.attname) if return_id else "*"
)
# Return metadata about the row, so we can tell if it was inserted or
# updated by checking the `xmax` Postgres system column.
if return_operation_type:
returning += f", ({self.RETURNING_OPERATION_TYPE_CLAUSE}) AS {self.RETURNING_OPERATION_TYPE_FIELD}"

(sql, params) = self._rewrite_insert_on_conflict(
sql, params, self.query.conflict_action.value, returning
sql,
params,
self.query.conflict_action.value,
returning,
)

return append_caller_to_sql(sql), params

def _rewrite_insert_on_conflict(
self, sql, params, conflict_action: ConflictAction, returning
):
self,
sql: str,
params: list,
conflict_action: ConflictAction,
returning: str,
) -> Tuple[str, list]:
"""Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
clause."""

Expand All @@ -256,7 +282,7 @@ def _rewrite_insert_on_conflict(

rewritten_sql += f" DO {conflict_action}"

if conflict_action == "UPDATE":
if conflict_action == ConflictAction.UPDATE.value:
rewritten_sql += f" SET {update_columns}"

if update_condition:
Expand Down Expand Up @@ -353,7 +379,7 @@ def _build_conflict_target_by_index(self):
stmt = matching_index.create_sql(self.query.model, schema_editor)
return "(%s)" % stmt.parts["columns"]

def _get_model_field(self, name: str):
def _get_model_field(self, name: str) -> Optional[Field]:
"""Gets the field on a model with the specified name.

Arguments:
Expand Down Expand Up @@ -384,7 +410,7 @@ def _get_model_field(self, name: str):

return None

def _format_field_name(self, field_name) -> str:
def _format_field_name(self, field_name):
"""Formats a field's name for usage in SQL.

Arguments:
Expand All @@ -399,7 +425,7 @@ def _format_field_name(self, field_name) -> str:
field = self._get_model_field(field_name)
return self.qn(field.column)

def _format_field_value(self, field_name) -> str:
def _format_field_value(self, field_name):
"""Formats a field's value for usage in SQL.

Arguments:
Expand Down Expand Up @@ -432,7 +458,8 @@ def _format_field_value(self, field_name) -> str:
)

def _compile_expression(
self, expression: Union[Expression, Q, str]
self,
expression: Union[Expression, Q, str],
) -> Tuple[str, Union[tuple, list]]:
"""Compiles an expression, Q object or raw SQL string into SQL and
tuple of parameters."""
Expand All @@ -452,7 +479,7 @@ def _compile_expression(

return expression, tuple()

def _assert_valid_field(self, field_name: str):
def _assert_valid_field(self, field_name: str) -> None:
"""Asserts that a field with the specified name exists on the model and
raises :see:SuspiciousOperation if it does not."""

Expand Down
40 changes: 33 additions & 7 deletions psqlextra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def bulk_insert(
rows: Iterable[dict],
return_model: bool = False,
using: Optional[str] = None,
return_operation_type: bool = False,
):
"""Creates multiple new records in the database.

Expand All @@ -158,6 +159,13 @@ def bulk_insert(
Optional name of the database connection to use for
this query.

return_operation_type (default: False):
If the operation type should be returned for each row.
This is only supported when return_model is False.
The operation_type is either 'INSERT' or 'UPDATE' and
the value will be contained in the '_operation_type' key
of the returned dict.

Returns:
A list of either the dicts of the rows inserted, including the pk or
the models of the rows inserted with defaults for any fields not specified
Expand Down Expand Up @@ -195,7 +203,10 @@ def is_empty(r):
deduped_rows.append(row)

compiler = self._build_insert_compiler(deduped_rows, using=using)
objs = compiler.execute_sql(return_id=not return_model)
objs = compiler.execute_sql(
return_id=not return_model,
return_operation_type=return_operation_type and not return_model,
)
if return_model:
return [
self._create_model_instance(dict(row, **obj), compiler.using)
Expand Down Expand Up @@ -261,7 +272,9 @@ def insert_and_get(self, using: Optional[str] = None, **fields):
return super().create(**fields)

compiler = self._build_insert_compiler([fields], using=using)
rows = compiler.execute_sql(return_id=False)
rows = compiler.execute_sql(
return_id=False, return_operation_type=False
)

if not rows:
return None
Expand Down Expand Up @@ -293,7 +306,7 @@ def upsert(
index_predicate: Optional[Union[Expression, Q, str]] = None,
using: Optional[str] = None,
update_condition: Optional[Union[Expression, Q, str]] = None,
) -> int:
) -> Optional[int]:
"""Creates a new record or updates the existing one with the specified
data.

Expand Down Expand Up @@ -336,7 +349,7 @@ def upsert_and_get(
index_predicate: Optional[Union[Expression, Q, str]] = None,
using: Optional[str] = None,
update_condition: Optional[Union[Expression, Q, str]] = None,
):
) -> Optional[TModel]:
"""Creates a new record or updates the existing one with the specified
data and then gets the row.

Expand Down Expand Up @@ -381,6 +394,7 @@ def bulk_upsert(
return_model: bool = False,
using: Optional[str] = None,
update_condition: Optional[Union[Expression, Q, str]] = None,
return_operation_type: bool = False,
):
"""Creates a set of new records or updates the existing ones with the
specified data.
Expand All @@ -407,6 +421,13 @@ def bulk_upsert(
update_condition:
Only update if this SQL expression evaluates to true.

return_operation_type (default: False):
If the operation type should be returned for each row.
This is only supported when return_model is False.
The operation_type is either 'INSERT' or 'UPDATE' and
the value will be contained in the '_operation_type' key
of the returned dict.

Returns:
A list of either the dicts of the rows upserted, including the pk or
the models of the rows upserted
Expand All @@ -418,15 +439,20 @@ def bulk_upsert(
index_predicate=index_predicate,
update_condition=update_condition,
)
return self.bulk_insert(rows, return_model, using=using)
return self.bulk_insert(
rows,
return_model,
using=using,
return_operation_type=return_operation_type,
)

def _create_model_instance(
self, field_values: dict, using: str, apply_converters: bool = True
):
"""Creates a new instance of the model with the specified field.

Use this after the row was inserted into the database. The new
instance will marked as "saved".
Use this after the row was inserted/updated into the database.
The new instance will be marked as "saved".
"""

converted_field_values = field_values.copy()
Expand Down
7 changes: 7 additions & 0 deletions psqlextra/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def all(cls) -> List["ConflictAction"]:
return [choice for choice in cls]


class UpsertOperation(StrEnum):
"""Possible operations to take on an upsert."""

INSERT = "INSERT"
UPDATE = "UPDATE"


class PostgresPartitioningMethod(StrEnum):
"""Methods of partitioning supported by PostgreSQL 11.x native support for
table partitioning."""
Expand Down
32 changes: 32 additions & 0 deletions tests/test_on_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from psqlextra.fields import HStoreField
from psqlextra.models import PostgresModel
from psqlextra.query import ConflictAction
from psqlextra.types import UpsertOperation

from .fake_model import get_fake_model

Expand Down Expand Up @@ -397,6 +398,37 @@ def test_bulk_return():
assert obj["id"] == index


def test_bulk_return_with_operation_type():
"""Tests if the _operation_type is properly returned from 'bulk_insert'."""

model = get_fake_model(
{
"id": models.BigAutoField(primary_key=True),
"name": models.CharField(max_length=255, unique=True),
}
)

rows = [dict(name="John Smith"), dict(name="Jane Doe")]

objs = model.objects.on_conflict(
["name"], ConflictAction.UPDATE
).bulk_insert(rows, return_operation_type=True)

for index, obj in enumerate(objs, 1):
assert obj["id"] == index
assert obj["_operation_type"] == UpsertOperation.INSERT.value

# Add objects again, update should return the same ids
# as we're just updating.
objs = model.objects.on_conflict(
["name"], ConflictAction.UPDATE
).bulk_insert(rows, return_operation_type=True)

for index, obj in enumerate(objs, 1):
assert obj["id"] == index
assert obj["_operation_type"] == UpsertOperation.UPDATE.value


@pytest.mark.parametrize("conflict_action", ConflictAction.all())
def test_bulk_return_models(conflict_action):
"""Tests whether models are returned instead of dictionaries when
Expand Down
39 changes: 39 additions & 0 deletions tests/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from psqlextra.expressions import ExcludedCol
from psqlextra.fields import HStoreField
from psqlextra.query import ConflictAction
from psqlextra.types import UpsertOperation

from .fake_model import get_fake_model

Expand Down Expand Up @@ -259,6 +260,44 @@ def test_upsert_bulk_no_rows():
)


def test_upsert_bulk_returns_operation_type():
"""Tests whether bulk_upsert works properly with the return_operation_type
flag."""

model = get_fake_model(
{
"first_name": models.CharField(
max_length=255, null=True, unique=True
),
"last_name": models.CharField(max_length=255, null=True),
}
)

rows = model.objects.bulk_upsert(
conflict_target=["first_name"],
rows=[
dict(first_name="Swen", last_name="Kooij"),
dict(first_name="Henk", last_name="Test"),
],
return_operation_type=True,
)

for row in rows:
assert row["_operation_type"] == UpsertOperation.INSERT.value

rows = model.objects.bulk_upsert(
conflict_target=["first_name"],
rows=[
dict(first_name="Swen", last_name="Test"),
dict(first_name="Henk", last_name="Kooij"),
],
return_operation_type=True,
)

for row in rows:
assert row["_operation_type"] == UpsertOperation.UPDATE.value


def test_bulk_upsert_return_models():
"""Tests whether models are returned instead of dictionaries when
specifying the return_model=True argument."""
Expand Down