Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sql_db_utils/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.0"
__version__ = "1.3.0"
27 changes: 22 additions & 5 deletions sql_db_utils/asyncio/sql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ async def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = N
table = table if table is not None else self.table
return_keys = return_keys or []
try:
insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys))
insert_stmt = (
insert(table)
.values(data)
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
)
return_values = await self.session.execute(insert_stmt)
await self.session.commit()
if return_keys:
Expand Down Expand Up @@ -93,7 +97,7 @@ async def update_with_where(
update(table)
.values(data)
.where(*where_conditions)
.returning(*(getattr(table.c, key) for key in return_keys))
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
)
return_values = await self.session.execute(update_stmt)
await self.session.commit()
Expand All @@ -115,7 +119,10 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
return_keys = return_keys or []
try:
return_values = await self.session.execute(
update(table).returning(*(getattr(table.c, key) for key in return_keys)), data
update(table).returning(
*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)
),
data,
)
await self.session.commit()
if return_keys:
Expand All @@ -124,24 +131,34 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
logger.error(f"Error occurred while updating: {e}", exc_info=True)
raise e

async def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None):
async def upsert(
self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None
):
"""
Inserts or updates a row in the database.

Args:
insert_json (dict): A dictionary containing the data to be inserted or updated.
primary_keys (List[str], optional): A list of primary key column names. Defaults to None.
return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None.
table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None.

Returns:
A list of dictionaries containing the upserted data if return_keys is provided.
"""
table = table if table is not None else self.table
return_keys = return_keys or []
try:
insert_statement = (
postgres_insert(table)
.values(**insert_json)
.on_conflict_do_update(index_elements=primary_keys, set_=insert_json)
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
)
await self.session.execute(insert_statement)
return_values = await self.session.execute(insert_statement)
await self.session.commit()
if return_keys:
return jsonable_encoder(return_values.mappings().all())
except Exception as e:
logger.error(f"Error while upserting the record {e}", exc_info=True)
raise e
Expand Down
27 changes: 22 additions & 5 deletions sql_db_utils/sql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t
table = table if table is not None else self.table
return_keys = return_keys or []
try:
insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys))
insert_stmt = (
insert(table)
.values(data)
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
)
return_values = self.session.execute(insert_stmt)
self.session.commit()
if return_keys:
Expand Down Expand Up @@ -83,7 +87,7 @@ def update_with_where(
update(table)
.values(data)
.where(*where_conditions)
.returning(*(getattr(table.c, key) for key in return_keys))
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
)
return_values = self.session.execute(update_stmt)
self.session.commit()
Expand All @@ -105,7 +109,10 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t
return_keys = return_keys or []
try:
return_values = self.session.execute(
update(table).returning(*(getattr(table.c, key) for key in return_keys)), data
update(table).returning(
*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)
),
data,
)
self.session.commit()
if return_keys:
Expand All @@ -114,24 +121,34 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t
logger.error(f"Error occurred while updating: {e}", exc_info=True)
raise e

def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None):
def upsert(
self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None
):
"""
Inserts or updates a row in the database.

Args:
insert_json (dict): A dictionary containing the data to be inserted or updated.
primary_keys (List[str], optional): A list of primary key column names. Defaults to None.
return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None.
table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None.

Returns:
A list of dictionaries containing the upserted data if return_keys is provided.
"""
table = table if table is not None else self.table
return_keys = return_keys or []
try:
insert_statement = (
postgres_insert(table)
.values(**insert_json)
.on_conflict_do_update(index_elements=primary_keys, set_=insert_json)
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
)
self.session.execute(insert_statement)
return_values = self.session.execute(insert_statement)
self.session.commit()
if return_keys:
return jsonable_encoder(return_values.mappings().all())
except Exception as e:
logger.error(f"Error while upserting the record {e}", exc_info=True)
raise e
Expand Down