@@ -62,7 +62,11 @@ async def insert(self, data: Union[dict, list[dict]], return_keys: List[str] = N
62
62
table = table if table is not None else self .table
63
63
return_keys = return_keys or []
64
64
try :
65
- insert_stmt = insert (table ).values (data ).returning (* (getattr (table .c , key ) for key in return_keys ))
65
+ insert_stmt = (
66
+ insert (table )
67
+ .values (data )
68
+ .returning (* (getattr (table .c if isinstance (table , Table ) else table , key ) for key in return_keys ))
69
+ )
66
70
return_values = await self .session .execute (insert_stmt )
67
71
await self .session .commit ()
68
72
if return_keys :
@@ -93,7 +97,7 @@ async def update_with_where(
93
97
update (table )
94
98
.values (data )
95
99
.where (* where_conditions )
96
- .returning (* (getattr (table .c , key ) for key in return_keys ))
100
+ .returning (* (getattr (table .c if isinstance ( table , Table ) else table , key ) for key in return_keys ))
97
101
)
98
102
return_values = await self .session .execute (update_stmt )
99
103
await self .session .commit ()
@@ -115,7 +119,10 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
115
119
return_keys = return_keys or []
116
120
try :
117
121
return_values = await self .session .execute (
118
- update (table ).returning (* (getattr (table .c , key ) for key in return_keys )), data
122
+ update (table ).returning (
123
+ * (getattr (table .c if isinstance (table , Table ) else table , key ) for key in return_keys )
124
+ ),
125
+ data ,
119
126
)
120
127
await self .session .commit ()
121
128
if return_keys :
@@ -124,24 +131,34 @@ async def update(self, data: Union[dict, list[dict]], return_keys: List[str] = N
124
131
logger .error (f"Error occurred while updating: { e } " , exc_info = True )
125
132
raise e
126
133
127
- async def upsert (self , insert_json : dict , primary_keys : List [str ] = None , table : TableType = None ):
134
+ async def upsert (
135
+ self , insert_json : dict , primary_keys : List [str ] = None , return_keys : List [str ] = None , table : TableType = None
136
+ ):
128
137
"""
129
138
Inserts or updates a row in the database.
130
139
131
140
Args:
132
141
insert_json (dict): A dictionary containing the data to be inserted or updated.
133
142
primary_keys (List[str], optional): A list of primary key column names. Defaults to None.
143
+ return_keys (List[str], optional): A list of column names to return after the upsert. Defaults to None.
134
144
table (TableType, optional): The SQLAlchemy declarative base object. Defaults to None.
145
+
146
+ Returns:
147
+ A list of dictionaries containing the upserted data if return_keys is provided.
135
148
"""
136
149
table = table if table is not None else self .table
150
+ return_keys = return_keys or []
137
151
try :
138
152
insert_statement = (
139
153
postgres_insert (table )
140
154
.values (** insert_json )
141
155
.on_conflict_do_update (index_elements = primary_keys , set_ = insert_json )
156
+ .returning (* (getattr (table .c if isinstance (table , Table ) else table , key ) for key in return_keys ))
142
157
)
143
- await self .session .execute (insert_statement )
158
+ return_values = await self .session .execute (insert_statement )
144
159
await self .session .commit ()
160
+ if return_keys :
161
+ return jsonable_encoder (return_values .mappings ().all ())
145
162
except Exception as e :
146
163
logger .error (f"Error while upserting the record { e } " , exc_info = True )
147
164
raise e
0 commit comments