@@ -223,6 +223,43 @@ def log_check_not_passed(self) -> None:
223223 self ._log_object_does_not_exist_message (name )
224224
225225
226+ class CreateForeignKey (DDLAlterOperation ):
227+ """Wraps alembic's create_foreign_key directive."""
228+
229+ def __init__ (
230+ self ,
231+ foreign_key_name : Optional [str ],
232+ table_name : str ,
233+ referent_table : str ,
234+ local_cols : List [str ],
235+ remote_cols : List [str ],
236+ ** kw : Any ,
237+ ) -> None :
238+ super ().__init__ (table_name )
239+ self .foreign_key_name = foreign_key_name
240+ self .referent_table = referent_table
241+ self .local_cols = local_cols
242+ self .remote_cols = remote_cols
243+ self .kw = kw
244+
245+ def batch_execute (self , batch_op ) -> None :
246+ batch_op .create_foreign_key (
247+ self .foreign_key_name , self .referent_table , self .local_cols , self .remote_cols , ** self .kw
248+ )
249+
250+ def non_batch_execute (self ) -> None :
251+ op .create_foreign_key (
252+ self .foreign_key_name , self .table_name , self .referent_table , self .local_cols , self .remote_cols , ** self .kw
253+ )
254+
255+ def pre_execute_check (self ) -> bool :
256+ return not foreign_key_exists (self .foreign_key_name , self .table_name , False )
257+
258+ def log_check_not_passed (self ) -> None :
259+ name = _table_object_description (self .foreign_key_name , self .table_name )
260+ self ._log_object_exists_message (name )
261+
262+
226263class CreateUniqueConstraint (DDLAlterOperation ):
227264 """Wraps alembic's create_unique_constraint directive."""
228265
@@ -294,6 +331,17 @@ def drop_index(index_name, table_name) -> None:
294331 DropIndex (index_name , table_name ).run ()
295332
296333
334+ def create_foreign_key (
335+ foreign_key_name : Optional [str ],
336+ table_name : str ,
337+ referent_table : str ,
338+ local_cols : List [str ],
339+ remote_cols : List [str ],
340+ ** kw : Any ,
341+ ) -> None :
342+ CreateForeignKey (foreign_key_name , table_name , referent_table , local_cols , remote_cols , ** kw ).run ()
343+
344+
297345def create_unique_constraint (constraint_name : str , table_name : str , columns : List [str ]) -> None :
298346 CreateUniqueConstraint (constraint_name , table_name , columns ).run ()
299347
@@ -328,6 +376,15 @@ def index_exists(index_name: str, table_name: str, default: bool) -> bool:
328376 return any (index ["name" ] == index_name for index in indexes )
329377
330378
379+ def foreign_key_exists (constraint_name : str , table_name : str , default : bool ) -> bool :
380+ """Check if unique constraint exists. If running in offline mode, return default."""
381+ if context .is_offline_mode ():
382+ _log_offline_mode_message (foreign_key_exists .__name__ , default )
383+ return default
384+ constraints = _inspector ().get_foreign_keys (table_name )
385+ return any (c ["name" ] == constraint_name for c in constraints )
386+
387+
331388def unique_constraint_exists (constraint_name : str , table_name : str , default : bool ) -> bool :
332389 """Check if unique constraint exists. If running in offline mode, return default."""
333390 if context .is_offline_mode ():
0 commit comments