Skip to content

Commit 0be0269

Browse files
authored
Merge pull request #4 from TechPrismatica/copilot/fix-70b4f7b5-03c5-485f-9930-f3baeb6f2a0a
Add return_keys support to upsert method and implement safe execution pattern for all CRUD operations
2 parents 9cbcd29 + 3488da7 commit 0be0269

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

sql_db_utils/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.2.0"
1+
__version__ = "1.3.0"

sql_db_utils/asyncio/sql_utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ async def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = N
6262
table = table if table is not None else self.table
6363
return_keys = return_keys or []
6464
try:
65-
insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys))
65+
insert_stmt = (
66+
insert(table)
67+
.values(data)
68+
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
69+
)
6670
return_values = await self.session.execute(insert_stmt)
6771
await self.session.commit()
6872
if return_keys:
@@ -93,7 +97,7 @@ async def update_with_where(
9397
update(table)
9498
.values(data)
9599
.where(*where_conditions)
96-
.returning(*(getattr(table.c, key) for key in return_keys))
100+
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
97101
)
98102
return_values = await self.session.execute(update_stmt)
99103
await self.session.commit()
@@ -115,7 +119,10 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
115119
return_keys = return_keys or []
116120
try:
117121
return_values = await self.session.execute(
118-
update(table).returning(*(getattr(table.c, key) for key in return_keys)), data
122+
update(table).returning(
123+
*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)
124+
),
125+
data,
119126
)
120127
await self.session.commit()
121128
if return_keys:
@@ -124,24 +131,34 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
124131
logger.error(f"Error occurred while updating: {e}", exc_info=True)
125132
raise e
126133

127-
async def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None):
134+
async def upsert(
135+
self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None
136+
):
128137
"""
129138
Inserts or updates a row in the database.
130139
131140
Args:
132141
insert_json (dict): A dictionary containing the data to be inserted or updated.
133142
primary_keys (List[str], optional): A list of primary key column names. Defaults to None.
143+
return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None.
134144
table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None.
145+
146+
Returns:
147+
A list of dictionaries containing the upserted data if return_keys is provided.
135148
"""
136149
table = table if table is not None else self.table
150+
return_keys = return_keys or []
137151
try:
138152
insert_statement = (
139153
postgres_insert(table)
140154
.values(**insert_json)
141155
.on_conflict_do_update(index_elements=primary_keys, set_=insert_json)
156+
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
142157
)
143-
await self.session.execute(insert_statement)
158+
return_values = await self.session.execute(insert_statement)
144159
await self.session.commit()
160+
if return_keys:
161+
return jsonable_encoder(return_values.mappings().all())
145162
except Exception as e:
146163
logger.error(f"Error while upserting the record {e}", exc_info=True)
147164
raise e

sql_db_utils/sql_utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t
5252
table = table if table is not None else self.table
5353
return_keys = return_keys or []
5454
try:
55-
insert_stmt = insert(table).values(data).returning(*(getattr(table.c, key) for key in return_keys))
55+
insert_stmt = (
56+
insert(table)
57+
.values(data)
58+
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
59+
)
5660
return_values = self.session.execute(insert_stmt)
5761
self.session.commit()
5862
if return_keys:
@@ -83,7 +87,7 @@ def update_with_where(
8387
update(table)
8488
.values(data)
8589
.where(*where_conditions)
86-
.returning(*(getattr(table.c, key) for key in return_keys))
90+
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
8791
)
8892
return_values = self.session.execute(update_stmt)
8993
self.session.commit()
@@ -105,7 +109,10 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t
105109
return_keys = return_keys or []
106110
try:
107111
return_values = self.session.execute(
108-
update(table).returning(*(getattr(table.c, key) for key in return_keys)), data
112+
update(table).returning(
113+
*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys)
114+
),
115+
data,
109116
)
110117
self.session.commit()
111118
if return_keys:
@@ -114,24 +121,34 @@ def update(self, data: Union[dict, list[dict]], return_keys: List[str] = None, t
114121
logger.error(f"Error occurred while updating: {e}", exc_info=True)
115122
raise e
116123

117-
def upsert(self, insert_json: dict, primary_keys: List[str] = None, table: TableType = None):
124+
def upsert(
125+
self, insert_json: dict, primary_keys: List[str] = None, return_keys: List[str] = None, table: TableType = None
126+
):
118127
"""
119128
Inserts or updates a row in the database.
120129
121130
Args:
122131
insert_json (dict): A dictionary containing the data to be inserted or updated.
123132
primary_keys (List[str], optional): A list of primary key column names. Defaults to None.
133+
return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None.
124134
table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None.
135+
136+
Returns:
137+
A list of dictionaries containing the upserted data if return_keys is provided.
125138
"""
126139
table = table if table is not None else self.table
140+
return_keys = return_keys or []
127141
try:
128142
insert_statement = (
129143
postgres_insert(table)
130144
.values(**insert_json)
131145
.on_conflict_do_update(index_elements=primary_keys, set_=insert_json)
146+
.returning(*(getattr(table.c if isinstance(table, Table) else table, key) for key in return_keys))
132147
)
133-
self.session.execute(insert_statement)
148+
return_values = self.session.execute(insert_statement)
134149
self.session.commit()
150+
if return_keys:
151+
return jsonable_encoder(return_values.mappings().all())
135152
except Exception as e:
136153
logger.error(f"Error while upserting the record {e}", exc_info=True)
137154
raise e

0 commit comments

Comments
 (0)