diff --git a/deploy/compose/docker-compose.yml b/deploy/compose/docker-compose.yml
index c177f8e..290d060 100644
--- a/deploy/compose/docker-compose.yml
+++ b/deploy/compose/docker-compose.yml
@@ -12,7 +12,8 @@ services:
     command: ["python", "-m", "sbosc.controller.main"]
     restart: always
     depends_on:
-      - redis
+      redis:
+        condition: service_healthy
 
   eventhandler:
     <<: *component-base
@@ -49,6 +50,11 @@ services:
     volumes:
       - redis-data:/data
       - ./redis.conf:/usr/local/etc/redis/redis.conf
+    healthcheck:
+      test: ["CMD", "redis-cli", "ping"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
 
 volumes:
   redis-data:
diff --git a/doc/config.md b/doc/config.md
index 2d4adc7..f0e9548 100644
--- a/doc/config.md
+++ b/doc/config.md
@@ -9,6 +9,11 @@ If you set this parameter to `True`, SB-OSC will skip the bulk import stage and
 ### disable_apply_dml_events
 If you set this parameter to `True`, SB-OSC will pause before `apply_dml_events` stage. This is useful when you have additional steps to perform manually before applying DML events.
 
+### disable_eventhandler
+If you set this parameter to `True`, SB-OSC will disable the event handler, which means binlog events will not be processed. Only the bulk import will be performed.
+
+After the `bulk_import_validation` stage it will move directly to the `done` stage, so the `add_index` stage will be skipped since the `apply_dml_events` stage will not be executed.
+
 ## Chunk
 
 ### max_chunk_count & min_chunk_size
@@ -34,3 +39,6 @@ These parameters control insert throughput of SB-OSC. `batch_size` and `thread_c
 
 `LIMIT batch_size` is applied to the next query to prevent from inserting too many rows at once.
 
+**Note:** This option uses `cursor.lastrowid` to set `last_inserted_pk`, which only returns a non-zero value when the table has an **AUTO_INCREMENT** column.
+([MySQL Document](https://dev.mysql.com/doc/connector-python/en/connector-python-api-mysqlcursor-lastrowid.html))
+
diff --git a/doc/operation-class.md b/doc/operation-class.md
index aa61f86..e9522a7 100644
--- a/doc/operation-class.md
+++ b/doc/operation-class.md
@@ -27,16 +27,17 @@ class MessageRetentionOperation(BaseOperation):
             INSERT INTO {self.source_db}.{self.destination_table}({self.source_columns})
             SELECT {self.source_columns}
             FROM {self.source_db}.{self.source_table} AS source
-            WHERE source.id BETWEEN {start_pk} AND {end_pk}
+            WHERE source.{self.pk_column} BETWEEN {start_pk} AND {end_pk}
             AND source.ts > DATE_SUB(NOW(), INTERVAL 30 DAY)
         """
 
     def _get_not_imported_pks_query(self, start_pk, end_pk):
         return f'''
-            SELECT source.id FROM {self.source_db}.{self.source_table} AS source
-            LEFT JOIN {self.source_db}.{self.destination_table} AS dest ON source.id = dest.id
-            WHERE source.id BETWEEN {start_pk} AND {end_pk}
+            SELECT source.{self.pk_column} FROM {self.source_db}.{self.source_table} AS source
+            LEFT JOIN {self.source_db}.{self.destination_table} AS dest
+                ON source.{self.pk_column} = dest.{self.pk_column}
+            WHERE source.{self.pk_column} BETWEEN {start_pk} AND {end_pk}
             AND source.ts > DATE_SUB(NOW(), INTERVAL 30 DAY)
-            AND dest.id IS NULL
+            AND dest.{self.pk_column} IS NULL
         '''
 ```
@@ -48,20 +49,20 @@ class CrossClusterMessageRetentionOperation(CrossClusterBaseOperation):
     def _select_batch_query(self, start_pk, end_pk):
         return f'''
             SELECT {self.source_columns} FROM {self.source_db}.{self.source_table}
-            WHERE id BETWEEN {start_pk} AND {end_pk}
+            WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk}
             AND source.ts > DATE_SUB(NOW(), INTERVAL 30 DAY)
         '''
 
     def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk):
source_cursor.execute(f''' - SELECT id FROM {self.source_db}.{self.source_table} - WHERE id BETWEEN {start_pk} AND {end_pk} + SELECT {self.pk_column} FROM {self.source_db}.{self.source_table} + WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk} AND source.ts > DATE_SUB(NOW(), INTERVAL 30 DAY) ''') source_pks = [row[0] for row in source_cursor.fetchall()] dest_cursor.execute(f''' - SELECT id FROM {self.destination_db}.{self.destination_table} - WHERE id BETWEEN {start_pk} AND {end_pk} + SELECT {self.pk_column} FROM {self.destination_db}.{self.destination_table} + WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk} AND source.ts > DATE_SUB(NOW(), INTERVAL 30 DAY) ''') dest_pks = [row[0] for row in dest_cursor.fetchall()] @@ -89,7 +90,7 @@ class MessageRetentionOperation(BaseOperation): INSERT INTO {self.source_db}.{self.destination_table}({self.source_columns}) SELECT {self.source_columns} FROM {self.source_db}.{self.source_table} AS source - WHERE source.id BETWEEN {start_pk} AND {end_pk} + WHERE source.{self.pk_column} BETWEEN {start_pk} AND {end_pk} AND source.ts > DATE_SUB(NOW(), INTERVAL {self.operation_config.retention_days} DAY) """ ``` diff --git a/src/config/config.py b/src/config/config.py index 81d8521..4f97233 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -84,6 +84,7 @@ class Config: USE_BATCH_SIZE_MULTIPLIER = False # EventHandler config + DISABLE_EVENTHANDLER = False EVENTHANDLER_THREAD_COUNT = 4 EVENTHANDLER_THREAD_TIMEOUT_IN_SECONDS = 300 INIT_BINLOG_FILE: str = None diff --git a/src/modules/db.py b/src/modules/db.py index 491f5e3..ecdbd30 100644 --- a/src/modules/db.py +++ b/src/modules/db.py @@ -74,7 +74,10 @@ def cursor(self, cursorclass=None): def ping(self): if not self._conn: self._conn = self.connect() - self._conn.ping() + try: + self._conn.ping() + except OperationalError: + self._conn = self.connect() def close(self): if self._conn: @@ -104,10 +107,7 @@ def get_connection(self): yield conn - try: - conn.ping() - except OperationalError: - conn = Connection(self.endpoint) + conn.ping() if self.free_connections.full(): raise Exception("Connection pool full") else: diff --git a/src/modules/redis/schema.py b/src/modules/redis/schema.py index c5fb30d..b590815 100644 --- a/src/modules/redis/schema.py +++ b/src/modules/redis/schema.py @@ -28,7 +28,8 @@ class Metadata(Hash): destination_db: str destination_table: str source_columns: str - max_id: int + pk_column: str + max_pk: int start_datetime: datetime diff --git a/src/sbosc/controller/controller.py b/src/sbosc/controller/controller.py index 6141214..19145a8 100644 --- a/src/sbosc/controller/controller.py +++ b/src/sbosc/controller/controller.py @@ -52,7 +52,6 @@ def start(self): if action: action() - # TODO: Add Redis data validation if needed time.sleep(self.interval) # Close db connection @@ -63,14 +62,14 @@ def create_bulk_import_chunks(self): self.redis_data.remove_all_chunks() metadata = self.redis_data.metadata - max_id = metadata.max_id + max_pk = metadata.max_pk # chunk_count is determined by min_chunk_size and max_chunk_count # Each chunk will have min_chunk_size rows and the number of chunks should not exceed max_chunk_count min_chunk_size = config.MIN_CHUNK_SIZE max_chunk_count = config.MAX_CHUNK_COUNT # Number of chunks means max number of worker threads - chunk_count = min(max_id // min_chunk_size, max_chunk_count) - chunk_size = max_id // chunk_count + chunk_count = min(max_pk // min_chunk_size, max_chunk_count) + chunk_size = max_pk // chunk_count # Create chunks # Each chunk will 
have a range of primary key values [start_pk, end_pk] @@ -79,7 +78,7 @@ def create_bulk_import_chunks(self): start_pk = i * chunk_size + 1 end_pk = (i + 1) * chunk_size if i == chunk_count - 1: - end_pk = max_id + end_pk = max_pk chunk_id = f"{self.migration_id}-{i}" chunk_info = self.redis_data.get_chunk_info(chunk_id) @@ -112,7 +111,7 @@ def create_bulk_import_chunks(self): self.redis_data.set_current_stage(Stage.BULK_IMPORT) self.slack.send_message( subtitle="Bulk import started", - message=f"Max id: {max_id}\n" + message=f"Max PK: {max_pk}\n" f"Chunk count: {chunk_count}\n" f"Chunk size: {chunk_size}\n" f"Batch size: {config.MIN_BATCH_SIZE}\n" @@ -166,7 +165,10 @@ def validate_bulk_import(self): self.redis_data.set_current_stage(Stage.BULK_IMPORT_VALIDATION_FAILED) self.slack.send_message(message="Bulk import validation failed", color="danger") else: - self.redis_data.set_current_stage(Stage.APPLY_DML_EVENTS) + if not config.DISABLE_EVENTHANDLER: + self.redis_data.set_current_stage(Stage.APPLY_DML_EVENTS) + else: + self.redis_data.set_current_stage(Stage.DONE) self.slack.send_message(message="Bulk import validation succeeded", color="good") except StopFlagSet: return @@ -213,45 +215,46 @@ def add_index(self): finished_all_creation = False while not self.stop_flag: finished_creation = False - with self.db.cursor() as cursor: - cursor: Cursor + with self.db.cursor(role='reader') as source_cursor: + source_cursor: Cursor index_info = None - cursor.execute(f''' + source_cursor.execute(f''' SELECT index_name FROM {config.SBOSC_DB}.index_creation_status WHERE migration_id = %s AND ended_at IS NULL AND started_at IS NOT NULL ''', (self.migration_id,)) - if cursor.rowcount > 0: - index_names = [row[0] for row in cursor.fetchall()] + if source_cursor.rowcount > 0: + index_names = [row[0] for row in source_cursor.fetchall()] self.slack.send_message( subtitle="Found unfinished index creation", message=f"Indexes: {index_names}", color="warning") while True: if self.stop_flag: return - cursor.execute(f''' - SELECT DISTINCT database_name, table_name, index_name FROM mysql.innodb_index_stats - WHERE database_name = %s AND table_name = %s - AND index_name IN ({','.join(['%s'] * len(index_names))}) - ''', [metadata.destination_db, metadata.destination_table] + index_names) - if cursor.rowcount == len(index_names): - finished_creation = True - break + with self.db.cursor(host='dest', role='reader') as dest_cursor: + dest_cursor.execute(f''' + SELECT DISTINCT database_name, table_name, index_name FROM mysql.innodb_index_stats + WHERE database_name = %s AND table_name = %s + AND index_name IN ({','.join(['%s'] * len(index_names))}) + ''', [metadata.destination_db, metadata.destination_table] + index_names) + if dest_cursor.rowcount == len(index_names): + finished_creation = True + break self.logger.info("Waiting for index creation to finish") time.sleep(60) else: - cursor.execute(f''' + source_cursor.execute(f''' SELECT index_name, index_columns, is_unique FROM {config.SBOSC_DB}.index_creation_status WHERE migration_id = %s AND ended_at IS NULL LIMIT {config.INDEX_CREATED_PER_QUERY} ''', (self.migration_id,)) - if cursor.rowcount == 0: + if source_cursor.rowcount == 0: finished_all_creation = True break - index_info = cursor.fetchall() + index_info = source_cursor.fetchall() index_names = [index_name for index_name, *_ in index_info] if index_info and not finished_creation: @@ -260,30 +263,30 @@ def add_index(self): # update ended_at started_at = datetime.now() - with self.db.cursor() as cursor: - cursor: 
Cursor - cursor.executemany(f''' + with self.db.cursor() as source_cursor: + source_cursor: Cursor + source_cursor.executemany(f''' UPDATE {config.SBOSC_DB}.index_creation_status SET started_at = %s WHERE migration_id = %s AND index_name = %s ''', [(started_at, self.migration_id, index_name) for index_name in index_names]) # add index - with self.db.cursor(host='dest') as cursor: - cursor: Cursor + with self.db.cursor(host='dest') as dest_cursor: + dest_cursor: Cursor # set session variables if config.INNODB_DDL_BUFFER_SIZE is not None: - cursor.execute(f"SET SESSION innodb_ddl_buffer_size = {config.INNODB_DDL_BUFFER_SIZE}") + dest_cursor.execute(f"SET SESSION innodb_ddl_buffer_size = {config.INNODB_DDL_BUFFER_SIZE}") self.logger.info(f"Set innodb_ddl_buffer_size to {config.INNODB_DDL_BUFFER_SIZE}") if config.INNODB_DDL_THREADS is not None: - cursor.execute(f"SET SESSION innodb_ddl_threads = {config.INNODB_DDL_THREADS}") + dest_cursor.execute(f"SET SESSION innodb_ddl_threads = {config.INNODB_DDL_THREADS}") self.logger.info(f"Set innodb_ddl_threads to {config.INNODB_DDL_THREADS}") if config.INNODB_PARALLEL_READ_THREADS is not None: - cursor.execute( + dest_cursor.execute( f"SET SESSION innodb_parallel_read_threads = {config.INNODB_PARALLEL_READ_THREADS}") self.logger.info(f"Set innodb_parallel_read_threads to {config.INNODB_PARALLEL_READ_THREADS}") - cursor.execute(f''' + dest_cursor.execute(f''' ALTER TABLE {metadata.destination_db}.{metadata.destination_table} {', '.join([ f"ADD{' UNIQUE' if is_unique else ''} INDEX {index_name} ({index_columns})" @@ -296,9 +299,9 @@ def add_index(self): if finished_creation: # update ended_at ended_at = datetime.now() - with self.db.cursor() as cursor: - cursor: Cursor - cursor.executemany(f''' + with self.db.cursor() as source_cursor: + source_cursor: Cursor + source_cursor.executemany(f''' UPDATE {config.SBOSC_DB}.index_creation_status SET ended_at = %s WHERE migration_id = %s AND index_name = %s ''', [(ended_at, self.migration_id, index_name) for index_name in index_names]) @@ -343,7 +346,7 @@ def swap_tables(self): old_source_table = f"{metadata.source_db}.{self.redis_data.old_source_table}" cursor.execute(f"RENAME TABLE {source_table} TO {old_source_table}") after_rename_table_timestamp = time.time() - cursor.execute(f"SELECT MAX(id) FROM {old_source_table}") + cursor.execute(f"SELECT MAX({metadata.pk_column}) FROM {old_source_table}") final_max_id = cursor.fetchone()[0] with self.validator.migration_operation.override_source_table(self.redis_data.old_source_table): diff --git a/src/sbosc/controller/initializer.py b/src/sbosc/controller/initializer.py index 36549a2..4285f02 100644 --- a/src/sbosc/controller/initializer.py +++ b/src/sbosc/controller/initializer.py @@ -174,14 +174,32 @@ def fetch_metadata(self, redis_data): metadata.source_columns = cursor.fetchone()[0] self.logger.info("Saved source column schema to Redis") - # Get max id - cursor.execute("SELECT MAX(id) FROM %s.%s" % (metadata.source_db, metadata.source_table)) - max_id = cursor.fetchone()[0] - metadata.max_id = max_id + # Get pk column + cursor.execute(f''' + SELECT COLUMN_NAME FROM information_schema.COLUMNS + WHERE TABLE_SCHEMA = '{metadata.source_db}' AND TABLE_NAME = '{metadata.source_table}' + AND COLUMN_KEY = 'PRI' AND DATA_TYPE IN ('int', 'bigint') + ''') + if cursor.rowcount == 0: + raise Exception("Integer primary key column not found") + metadata.pk_column = f"`{cursor.fetchone()[0]}`" + self.logger.info("Saved primary key column to Redis") + + # Get max PK + 
cursor.execute(''' + SELECT MAX(%s) FROM %s.%s + ''' % (metadata.pk_column, metadata.source_db, metadata.source_table)) + max_pk = cursor.fetchone()[0] + if max_pk is None: + raise Exception("No data in source table") + metadata.max_pk = max_pk self.logger.info("Saved total rows to Redis") metadata.start_datetime = datetime.now() - redis_data.set_current_stage(Stage.START_EVENT_HANDLER) + if not config.DISABLE_EVENTHANDLER: + redis_data.set_current_stage(Stage.START_EVENT_HANDLER) + else: + redis_data.set_current_stage(Stage.BULK_IMPORT_CHUNK_CREATION) def init_migration(self): if not self.check_database_setup(): diff --git a/src/sbosc/controller/validator.py b/src/sbosc/controller/validator.py index 0dfb97d..757dc3d 100644 --- a/src/sbosc/controller/validator.py +++ b/src/sbosc/controller/validator.py @@ -69,8 +69,8 @@ def __validate_bulk_import_batch(self, range_queue: Queue, failed_pks): return False except MySQLdb.OperationalError as e: self.__handle_operational_error(e, range_queue, batch_start_pk, batch_end_pk) - source_conn.ping(True) - dest_conn.ping(True) + source_conn.ping() + dest_conn.ping() continue except Empty: self.logger.warning("Range queue is empty") @@ -83,8 +83,8 @@ def bulk_import_validation(self): metadata = self.redis_data.metadata range_queue = Queue() start_pk = 0 - while start_pk <= metadata.max_id: - range_queue.put((start_pk, min(start_pk + self.bulk_import_batch_size, metadata.max_id))) + while start_pk <= metadata.max_pk: + range_queue.put((start_pk, min(start_pk + self.bulk_import_batch_size, metadata.max_pk))) start_pk += self.bulk_import_batch_size + 1 failed_pks = [] @@ -153,13 +153,15 @@ def __execute_apply_dml_events_validation_query( if event_pks: event_pks_str = ','.join([str(pk) for pk in event_pks]) dest_cursor.execute(f''' - SELECT id FROM {metadata.destination_db}.{metadata.destination_table} WHERE id IN ({event_pks_str}) + SELECT {metadata.pk_column} FROM {metadata.destination_db}.{metadata.destination_table} + WHERE {metadata.pk_column} IN ({event_pks_str}) ''') not_deleted_pks = set([row[0] for row in dest_cursor.fetchall()]) if dest_cursor.rowcount > 0: # Check if deleted pks are reinserted source_cursor.execute(f''' - SELECT id FROM {metadata.source_db}.{metadata.source_table} WHERE id IN ({event_pks_str}) + SELECT {metadata.pk_column} FROM {metadata.source_db}.{metadata.source_table} + WHERE {metadata.pk_column} IN ({event_pks_str}) ''') reinserted_pks = set([row[0] for row in source_cursor.fetchall()]) if reinserted_pks: @@ -224,8 +226,8 @@ def __validate_apply_dml_events_batch(self, table, range_queue: Queue, unmatched ) except MySQLdb.OperationalError as e: self.__handle_operational_error(e, range_queue, batch_start_timestamp, batch_end_timestamp) - source_conn.ping(True) - dest_conn.ping(True) + source_conn.ping() + dest_conn.ping() continue def __validate_unmatched_pks(self): diff --git a/src/sbosc/eventhandler/eventhandler.py b/src/sbosc/eventhandler/eventhandler.py index 95ce264..259fe7e 100644 --- a/src/sbosc/eventhandler/eventhandler.py +++ b/src/sbosc/eventhandler/eventhandler.py @@ -1,6 +1,5 @@ import concurrent.futures import time -from queue import Queue, Empty from threading import Thread from MySQLdb.cursors import Cursor, DictCursor @@ -85,7 +84,7 @@ def __init__(self): 'passwd': secret.PASSWORD, } - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.thread_count) + self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=self.thread_count) self.log_file = None self.log_pos = None @@ -188,7 
+187,7 @@ def start(self): self.logger.info('Starting event handler') while not self.stop_flag: current_stage = self.redis_data.current_stage - if Stage.DONE > current_stage >= Stage.START_EVENT_HANDLER: + if Stage.DONE > current_stage >= Stage.START_EVENT_HANDLER and not config.DISABLE_EVENTHANDLER: if self.log_file is None or self.log_pos is None: self.logger.info('Initializing event handler') self.init_event_handler() @@ -258,26 +257,21 @@ def apply_dml_events_pre_validation(self): else: self.redis_data.set_current_stage(Stage.ADD_INDEX) - def parse_binlog_batch(self, thread_id, batch_queue: Queue, done_batch: list): + @staticmethod + def parse_binlog_batch(stream): event_store = EventStore() - while batch_queue.qsize() > 0 and not self.stop_flag: - try: - binlog_file, start_pos = batch_queue.get_nowait() - except Empty: - self.logger.warning('Binlog batch queue is empty') - continue - stream = self.create_binlog_stream(binlog_file, start_pos, thread_id) - for event in stream: - event_store.add_event(event) - if stream.log_file != binlog_file: - break - - done_batch.append((stream.log_file, stream.log_pos)) - stream.close() - return event_store + start_file = stream.log_file + for event in stream: + event_store.add_event(event) + if stream.log_file != start_file: + break + end_file = stream.log_file + end_pos = stream.log_pos + stream.close() + return event_store, (end_file, end_pos) def follow_event_stream(self): - file_queue = Queue() + target_files = [] # Create binlog batch queue with self.db.cursor(DictCursor) as cursor: @@ -293,25 +287,31 @@ def follow_event_stream(self): ] for log_file in binlog_files[:self.thread_count]: start_pos = self.log_pos if log_file == self.log_file else 4 - file_queue.put((log_file, start_pos)) + target_files.append((log_file, start_pos)) # Parse binlog batches threads = [] - done_files = [] - queued_files = file_queue.qsize() event_store = EventStore() result_event_stores = [] + done_files = [] - for i in range(self.thread_count): - threads.append(self.executor.submit(self.parse_binlog_batch, i, file_queue, done_files)) + for thread_id in range(len(target_files)): + binlog_file, start_pos = target_files[thread_id] + stream = self.create_binlog_stream(binlog_file, start_pos, thread_id) + threads.append(self.executor.submit(self.parse_binlog_batch, stream)) done, not_done = concurrent.futures.wait(threads, timeout=self.thread_timeout) if len(not_done) > 0: self.set_stop_flag() raise Exception('Binlog batch parsing timed out') + for thread in threads: - result_event_stores.append(thread.result()) + result_event_store, done_file = thread.result() + result_event_stores.append(result_event_store) + done_files.append(done_file) - if len(done_files) == queued_files: + if self.stop_flag: + self.logger.info('Binlog parsing stopped') + else: self.log_file, self.log_pos = max(done_files) self.handled_binlog_files = self.handled_binlog_files | set([binlog_file for binlog_file, _ in done_files]) @@ -340,8 +340,3 @@ def follow_event_stream(self): if len(binlog_files) == 1: self.redis_data.set_last_catchup_timestamp(last_binlog_check_timestamp) - - elif self.stop_flag: - self.logger.info('Binlog parsing stopped') - else: - self.logger.error('Binlog parsing failed') diff --git a/src/sbosc/monitor/monitor.py b/src/sbosc/monitor/monitor.py index 21794b9..b6cbf07 100644 --- a/src/sbosc/monitor/monitor.py +++ b/src/sbosc/monitor/monitor.py @@ -306,10 +306,13 @@ def check_migration_status(self): if last_pk_inserted and last_pk_inserted >= chunk_info.start_pk: inserted_rows 
+= last_pk_inserted - chunk_info.start_pk - if self.redis_data.metadata.max_id: - bulk_import_progress = inserted_rows / self.redis_data.metadata.max_id * 100 + if self.redis_data.metadata.max_pk: + bulk_import_progress = inserted_rows / self.redis_data.metadata.max_pk * 100 self.metric_sender.submit('sb_osc_bulk_import_progress', bulk_import_progress) + if config.DISABLE_EVENTHANDLER: + return + self.submit_event_handler_timestamps() # remaining_binlog_size diff --git a/src/sbosc/operations/base.py b/src/sbosc/operations/base.py index 752de21..b671898 100644 --- a/src/sbosc/operations/base.py +++ b/src/sbosc/operations/base.py @@ -12,13 +12,13 @@ def _insert_batch_query(self, start_pk, end_pk): INSERT INTO {self.destination_db}.{self.destination_table}({self.source_columns}) SELECT {self.source_columns} FROM {self.source_db}.{self.source_table} AS source - WHERE id BETWEEN {start_pk} AND {end_pk} + WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk} ''' def insert_batch(self, db, start_pk, end_pk, upsert=False, limit=None): query = self._insert_batch_query(start_pk, end_pk) if limit: - query = operation_utils.apply_limit(query, limit) + query = operation_utils.apply_limit(query, self.pk_column, limit) if upsert: query = operation_utils.insert_to_upsert(query, self.source_column_list) with db.cursor() as cursor: @@ -32,18 +32,29 @@ def apply_update(self, db, updated_pks): INSERT INTO {self.destination_db}.{self.destination_table}({self.source_columns}) SELECT {self.source_columns} FROM {self.source_db}.{self.source_table} - WHERE id IN ({updated_pks_str}) + WHERE {self.pk_column} IN ({updated_pks_str}) ''' query = operation_utils.insert_to_upsert(query, self.source_column_list) cursor.execute(query) return cursor + def get_max_pk(self, db, start_pk, end_pk): + metadata = self.redis_data.metadata + with db.cursor(host='dest') as cursor: + cursor: Cursor + cursor.execute(f''' + SELECT MAX({self.pk_column}) FROM {metadata.destination_db}.{metadata.destination_table} + WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk} + ''') + return cursor.fetchone()[0] + def _get_not_imported_pks_query(self, start_pk, end_pk): return f''' - SELECT source.id FROM {self.source_db}.{self.source_table} AS source - LEFT JOIN {self.destination_db}.{self.destination_table} AS dest ON source.id = dest.id - WHERE source.id BETWEEN {start_pk} AND {end_pk} - AND dest.id IS NULL + SELECT source.{self.pk_column} FROM {self.source_db}.{self.source_table} AS source + LEFT JOIN {self.destination_db}.{self.destination_table} AS dest + ON source.{self.pk_column} = dest.{self.pk_column} + WHERE source.{self.pk_column} BETWEEN {start_pk} AND {end_pk} + AND dest.{self.pk_column} IS NULL ''' def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk): @@ -59,10 +70,11 @@ def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks): if event_pks: event_pks_str = ','.join([str(pk) for pk in event_pks]) source_cursor.execute(f''' - SELECT source.id FROM {self.source_db}.{self.source_table} AS source - LEFT JOIN {self.destination_db}.{self.destination_table} AS dest ON source.id = dest.id - WHERE source.id IN ({event_pks_str}) - AND dest.id IS NULL + SELECT source.{self.pk_column} FROM {self.source_db}.{self.source_table} AS source + LEFT JOIN {self.destination_db}.{self.destination_table} AS dest + ON source.{self.pk_column} = dest.{self.pk_column} + WHERE source.{self.pk_column} IN ({event_pks_str}) + AND dest.{self.pk_column} IS NULL ''') not_inserted_pks = [row[0] for row in 
source_cursor.fetchall()] return not_inserted_pks @@ -72,15 +84,15 @@ def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks): if event_pks: event_pks_str = ','.join([str(pk) for pk in event_pks]) source_cursor.execute(f''' - SELECT combined.id + SELECT combined.{self.pk_column} FROM ( SELECT {self.source_columns}, 'source' AS table_type FROM {self.source_db}.{self.source_table} - WHERE id IN ({event_pks_str}) + WHERE {self.pk_column} IN ({event_pks_str}) UNION ALL SELECT {self.source_columns}, 'destination' AS table_type FROM {self.destination_db}.{self.destination_table} - WHERE id IN ({event_pks_str}) + WHERE {self.pk_column} IN ({event_pks_str}) ) AS combined GROUP BY {self.source_columns} HAVING COUNT(1) = 1 AND SUM(table_type = 'source') = 1 @@ -93,17 +105,18 @@ def get_rematched_updated_pks(self, db, not_updated_pks): with db.cursor() as cursor: cursor: Cursor cursor.execute(f''' - SELECT combined.id FROM ( + SELECT combined.{self.pk_column} FROM ( SELECT {self.source_columns} FROM {self.source_db}.{self.source_table} - WHERE id IN ({not_updated_pks_str}) UNION ALL + WHERE {self.pk_column} IN ({not_updated_pks_str}) UNION ALL SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table} - WHERE id IN ({not_updated_pks_str}) + WHERE {self.pk_column} IN ({not_updated_pks_str}) ) AS combined GROUP BY {self.source_columns} HAVING COUNT(*) = 2 ''') rematched_pks = set([row[0] for row in cursor.fetchall()]) # add deleted pks cursor.execute(f''' - SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({not_updated_pks_str}) + SELECT {self.pk_column} FROM {self.source_db}.{self.source_table} + WHERE {self.pk_column} IN ({not_updated_pks_str}) ''') remaining_pks = set([row[0] for row in cursor.fetchall()]) deleted_pks = not_updated_pks - remaining_pks @@ -115,14 +128,15 @@ def get_rematched_removed_pks(self, db, not_removed_pks): cursor: Cursor cursor.execute(f''' SELECT source_pk FROM {config.SBOSC_DB}.unmatched_rows WHERE source_pk NOT IN ( - SELECT id FROM {self.destination_db}.{self.destination_table} - WHERE id IN ({not_removed_pks_str}) + SELECT {self.pk_column} FROM {self.destination_db}.{self.destination_table} + WHERE {self.pk_column} IN ({not_removed_pks_str}) ) AND source_pk IN ({not_removed_pks_str}) AND migration_id = {self.migration_id} ''') rematched_pks = set([row[0] for row in cursor.fetchall()]) # add reinserted pks cursor.execute(f''' - SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({not_removed_pks_str}) + SELECT {self.pk_column} FROM {self.source_db}.{self.source_table} + WHERE {self.pk_column} IN ({not_removed_pks_str}) ''') reinserted_pks = set([row[0] for row in cursor.fetchall()]) return rematched_pks | reinserted_pks @@ -133,13 +147,13 @@ def _select_batch_query(self, start_pk, end_pk): return f''' SELECT {self.source_columns} FROM {self.source_db}.{self.source_table} AS source - WHERE id BETWEEN {start_pk} AND {end_pk} + WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk} ''' def insert_batch(self, db, start_pk, end_pk, upsert=False, limit=None): select_batch_query = self._select_batch_query(start_pk, end_pk) if limit: - select_batch_query = operation_utils.apply_limit(select_batch_query, limit) + select_batch_query = operation_utils.apply_limit(select_batch_query, self.pk_column, limit) with db.cursor(host='source', role='reader') as cursor: cursor.execute(select_batch_query) rows = cursor.fetchall() @@ -162,7 +176,7 @@ def apply_update(self, db, updated_pks): cursor: Cursor cursor.execute(f''' SELECT 
{self.source_columns} FROM {self.source_db}.{self.source_table} - WHERE id IN ({updated_pks_str}) + WHERE {self.pk_column} IN ({updated_pks_str}) ''') rows = cursor.fetchall() if rows: @@ -177,16 +191,28 @@ def apply_update(self, db, updated_pks): else: return cursor + def get_max_pk(self, db, start_pk, end_pk): + metadata = self.redis_data.metadata + with db.cursor(host='dest') as cursor: + cursor: Cursor + cursor.execute(f''' + SELECT MAX({self.pk_column}) FROM {metadata.destination_db}.{metadata.destination_table} + WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk} + ''') + return cursor.fetchone()[0] + + def _get_not_imported_pks_query(self, table, start_pk, end_pk): + return f''' + SELECT {self.pk_column} FROM {table} AS source + WHERE {self.pk_column} BETWEEN {start_pk} AND {end_pk} + ''' + def get_not_imported_pks(self, source_cursor, dest_cursor, start_pk, end_pk): - source_cursor.execute(f''' - SELECT id FROM {self.source_db}.{self.source_table} - WHERE id BETWEEN {start_pk} AND {end_pk} - ''') + source_cursor.execute( + self._get_not_imported_pks_query(f'{self.source_db}.{self.source_table}', start_pk, end_pk)) source_pks = [row[0] for row in source_cursor.fetchall()] - dest_cursor.execute(f''' - SELECT id FROM {self.destination_db}.{self.destination_table} - WHERE id BETWEEN {start_pk} AND {end_pk} - ''') + dest_cursor.execute( + self._get_not_imported_pks_query(f'{self.destination_db}.{self.destination_table}', start_pk, end_pk)) dest_pks = [row[0] for row in dest_cursor.fetchall()] return list(set(source_pks) - set(dest_pks)) @@ -194,10 +220,15 @@ def get_not_inserted_pks(self, source_cursor, dest_cursor, event_pks): not_inserted_pks = [] if event_pks: event_pks_str = ','.join([str(pk) for pk in event_pks]) - source_cursor.execute(f"SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({event_pks_str})") + source_cursor.execute(f''' + SELECT {self.pk_column} FROM {self.source_db}.{self.source_table} + WHERE {self.pk_column} IN ({event_pks_str}) + ''') source_pks = [row[0] for row in source_cursor.fetchall()] - dest_cursor.execute( - f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({event_pks_str})") + dest_cursor.execute(f''' + SELECT {self.pk_column} FROM {self.destination_db}.{self.destination_table} + WHERE {self.pk_column} IN ({event_pks_str}) + ''') dest_pks = [row[0] for row in dest_cursor.fetchall()] not_inserted_pks = list(set(source_pks) - set(dest_pks)) return not_inserted_pks @@ -208,18 +239,18 @@ def get_not_updated_pks(self, source_cursor, dest_cursor, event_pks): event_pks_str = ','.join([str(pk) for pk in event_pks]) source_cursor.execute(f''' SELECT {self.source_columns} FROM {self.source_db}.{self.source_table} - WHERE id IN ({event_pks_str}) + WHERE {self.pk_column} IN ({event_pks_str}) ''') source_df = pd.DataFrame(source_cursor.fetchall(), columns=[c[0] for c in source_cursor.description]) dest_cursor.execute(f''' SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table} - WHERE id IN ({event_pks_str}) + WHERE {self.pk_column} IN ({event_pks_str}) ''') dest_df = pd.DataFrame(dest_cursor.fetchall(), columns=[c[0] for c in dest_cursor.description]) dest_df = dest_df[source_df.columns] - source_df.set_index('id', inplace=True) - dest_df.set_index('id', inplace=True) + source_df.set_index(self.pk_column.strip('`'), inplace=True) + dest_df.set_index(self.pk_column.strip('`'), inplace=True) common_index = dest_df.index.intersection(source_df.index) source_df = source_df.loc[common_index] dest_df = 
dest_df.loc[common_index] @@ -236,25 +267,26 @@ def get_rematched_updated_pks(self, db, not_updated_pks): cursor: Cursor cursor.execute(f''' SELECT {self.source_columns} FROM {self.source_db}.{self.source_table} - WHERE id IN ({not_updated_pks_str}) + WHERE {self.pk_column} IN ({not_updated_pks_str}) ''') source_df = pd.DataFrame(cursor.fetchall(), columns=[c[0] for c in cursor.description]) with db.cursor(host='dest', role='reader') as cursor: cursor: Cursor cursor.execute(f''' SELECT {self.source_columns} FROM {self.destination_db}.{self.destination_table} - WHERE id IN ({not_updated_pks_str}) + WHERE {self.pk_column} IN ({not_updated_pks_str}) ''') dest_df = pd.DataFrame(cursor.fetchall(), columns=[c[0] for c in cursor.description]) dest_df = dest_df.astype(source_df.dtypes.to_dict()) merged_df = source_df.merge(dest_df, how='inner', on=source_df.columns.tolist(), indicator=True) - rematched_pks = set(merged_df[merged_df['_merge'] == 'both']['id'].tolist()) + rematched_pks = set(merged_df[merged_df['_merge'] == 'both'][self.pk_column.strip('`')].tolist()) except pd.errors.IntCastingNaNError: rematched_pks = set() # add deleted pks with db.cursor(host='source', role='reader') as cursor: cursor.execute(f''' - SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({not_updated_pks_str}) + SELECT {self.pk_column} FROM {self.source_db}.{self.source_table} + WHERE {self.pk_column} IN ({not_updated_pks_str}) ''') remaining_pks = set([row[0] for row in cursor.fetchall()]) deleted_pks = not_updated_pks - remaining_pks @@ -264,8 +296,10 @@ def get_rematched_removed_pks(self, db, not_removed_pks): not_removed_pks_str = ','.join([str(pk) for pk in not_removed_pks]) with db.cursor(host='dest', role='reader') as cursor: cursor: Cursor - cursor.execute( - f"SELECT id FROM {self.destination_db}.{self.destination_table} WHERE id IN ({not_removed_pks_str})") + cursor.execute(f''' + SELECT {self.pk_column} FROM {self.destination_db}.{self.destination_table} + WHERE {self.pk_column} IN ({not_removed_pks_str}) + ''') still_not_removed_pks_str = ','.join([str(row[0]) for row in cursor.fetchall()]) with db.cursor(host='source', role='reader') as cursor: cursor: Cursor @@ -279,7 +313,8 @@ def get_rematched_removed_pks(self, db, not_removed_pks): rematched_pks = set([row[0] for row in cursor.fetchall()]) # add reinserted pks cursor.execute(f''' - SELECT id FROM {self.source_db}.{self.source_table} WHERE id IN ({not_removed_pks_str}) + SELECT {self.pk_column} FROM {self.source_db}.{self.source_table} + WHERE {self.pk_column} IN ({not_removed_pks_str}) ''') reinserted_pks = set([row[0] for row in cursor.fetchall()]) return rematched_pks | reinserted_pks diff --git a/src/sbosc/operations/operation.py b/src/sbosc/operations/operation.py index 327e74d..8148b75 100644 --- a/src/sbosc/operations/operation.py +++ b/src/sbosc/operations/operation.py @@ -30,6 +30,7 @@ def __init__(self, migration_id): self.destination_table = metadata.destination_table self.source_columns: str = metadata.source_columns self.source_column_list: list = metadata.source_columns.split(',') + self.pk_column = metadata.pk_column self.start_datetime = metadata.start_datetime self.operation_config = self.operation_config_class(**config.OPERATION_CLASS_CONFIG) @@ -49,6 +50,14 @@ def apply_update(self, db: Database, updated_pks: list) -> Cursor: """ pass + @abstractmethod + def get_max_pk(self, db: Database, start_pk: int, end_pk: int) -> int: + """ + Returns the maximum primary key in the destination table. 
+ Used when chunk status is DUPLICATE_KEY to determine starting batch range. + """ + pass + @abstractmethod def get_not_imported_pks(self, source_cursor: Cursor, dest_cursor: Cursor, start_pk: int, end_pk: int) -> list: """ diff --git a/src/sbosc/operations/utils.py b/src/sbosc/operations/utils.py index 1577e17..5f4be17 100644 --- a/src/sbosc/operations/utils.py +++ b/src/sbosc/operations/utils.py @@ -6,5 +6,5 @@ def insert_to_upsert(query: str, source_columns: list) -> str: return query -def apply_limit(query: str, limit: int) -> str: - return query + f" ORDER BY id LIMIT {limit}" +def apply_limit(query: str, pk_column: str, limit: int) -> str: + return query + f" ORDER BY {pk_column} LIMIT {limit}" diff --git a/src/sbosc/worker/worker.py b/src/sbosc/worker/worker.py index 1b649f1..44b96e7 100644 --- a/src/sbosc/worker/worker.py +++ b/src/sbosc/worker/worker.py @@ -100,7 +100,7 @@ def get_start_pk(self, chunk_info: ChunkInfo): elif chunk_info.status == ChunkStatus.IN_PROGRESS: return chunk_info.last_pk_inserted + 1 elif chunk_info.status == ChunkStatus.DUPLICATE_KEY: - max_pk = self.get_max_pk(chunk_info.start_pk, chunk_info.end_pk) + max_pk = self.migration_operation.get_max_pk(self.db, chunk_info.start_pk, chunk_info.end_pk) return max_pk + 1 def bulk_import(self): @@ -144,7 +144,11 @@ def bulk_import(self): self.worker_config.update_batch_size_multiplier(cursor.rowcount) # update last pk inserted - if cursor.rowcount == self.worker_config.raw_batch_size: + # If batch size multiplier is used, + # there can be remaining rows between cursor.lastrowid and batch_end_pk + # because of the limit clause in the query. + # Note that cursor.lastrowid is a non-zero value only if pk is auto-incremented. + if self.use_batch_size_multiplier and cursor.rowcount == self.worker_config.raw_batch_size: last_pk_inserted = cursor.lastrowid else: last_pk_inserted = batch_end_pk @@ -206,16 +210,6 @@ def apply_dml_events(self): except Exception as e: self.logger.error(e) - def get_max_pk(self, start_pk, end_pk): - metadata = self.redis_data.metadata - with self.db.cursor(host='dest') as cursor: - cursor: Cursor - cursor.execute(f''' - SELECT MAX(id) FROM {metadata.destination_db}.{metadata.destination_table} - WHERE id BETWEEN {start_pk} AND {end_pk} - ''') - return cursor.fetchone()[0] - @staticmethod def calculate_metrics(func: Callable[..., Cursor]): def wrapper(self: Self, *args, **kwargs): @@ -247,7 +241,7 @@ def apply_delete(self, removed_pks): removed_pks_str = ",".join([str(pk) for pk in removed_pks]) query = f""" DELETE FROM {metadata.destination_db}.{metadata.destination_table} - WHERE id IN ({removed_pks_str}) + WHERE {metadata.pk_column} IN ({removed_pks_str}) """ cursor.execute(query) return cursor diff --git a/tests/conftest.py b/tests/conftest.py index 7c82f5c..e07fc70 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -89,8 +89,9 @@ def init_migration(config, cursor, redis_data, migration_id): cursor.execute(f'DROP DATABASE IF EXISTS {db}') cursor.execute(f'CREATE DATABASE {db}') - cursor.execute(f'CREATE TABLE {config.SOURCE_DB}.{config.SOURCE_TABLE} (id int)') - cursor.execute(f'CREATE TABLE {config.DESTINATION_DB}.{config.DESTINATION_TABLE} (id int)') + cursor.execute(f'CREATE TABLE {config.SOURCE_DB}.{config.SOURCE_TABLE} (id int AUTO_INCREMENT PRIMARY KEY)') + cursor.execute(f'INSERT INTO {config.SOURCE_DB}.{config.SOURCE_TABLE} VALUES (1)') + cursor.execute(f'CREATE TABLE {config.DESTINATION_DB}.{config.DESTINATION_TABLE} (id int AUTO_INCREMENT PRIMARY KEY)') retrieved_migration_id 
= Initializer().init_migration() diff --git a/tests/test_eventhandler.py b/tests/test_eventhandler.py index 5fce90b..705dd72 100644 --- a/tests/test_eventhandler.py +++ b/tests/test_eventhandler.py @@ -97,7 +97,7 @@ def test_event_handler_save_to_database(event_handler, cursor, redis_data): time.sleep(100) total_events = 1000 - redis_data.metadata.max_id = total_events + redis_data.metadata.max_pk = total_events insert_events = total_events // 2 update_events = (total_events - insert_events) // 2 delete_events = total_events - insert_events - update_events diff --git a/tests/test_monitor.py b/tests/test_monitor.py index a9a566d..2811d71 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -113,7 +113,7 @@ def test_update_worker_config(monitor, redis_data): def test_check_migration_status(monitor, cursor, redis_data): cursor.execute(f"TRUNCATE TABLE {config.SBOSC_DB}.event_handler_status") cursor.execute(f"TRUNCATE TABLE {config.SBOSC_DB}.apply_dml_events_status") - monitor.redis_data.metadata.max_id = 0 + monitor.redis_data.metadata.max_pk = 0 monitor.check_migration_status() metric_set = get_metric_names(monitor) expected_metrics = { @@ -124,7 +124,7 @@ def test_check_migration_status(monitor, cursor, redis_data): } assert metric_set == expected_metrics - monitor.redis_data.metadata.max_id = 100 + monitor.redis_data.metadata.max_pk = 100 monitor.check_migration_status() metric_set = get_metric_names(monitor) expected_metrics.add('sb_osc_bulk_import_progress')
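
The initializer change in this patch replaces the hard-coded `SELECT MAX(id)` with a primary-key discovery query followed by a max-PK lookup. As an illustration only (not part of the patch), the sketch below runs the same two lookups that the updated `Initializer.fetch_metadata` performs, so the chosen PK column and max PK can be checked before starting a migration. The connection parameters and the `SOURCE_DB` / `SOURCE_TABLE` names are placeholders.

```python
# Illustration only -- not part of the patch. Mirrors the two lookups added to
# Initializer.fetch_metadata so they can be run as a pre-flight check.
# Connection parameters and schema/table names below are placeholders.
import MySQLdb

SOURCE_DB = "source_db"        # placeholder
SOURCE_TABLE = "source_table"  # placeholder

conn = MySQLdb.connect(host="127.0.0.1", user="root", passwd="password")
cursor = conn.cursor()

# Same query the initializer now uses to pick an integer primary key column
cursor.execute(f'''
    SELECT COLUMN_NAME FROM information_schema.COLUMNS
    WHERE TABLE_SCHEMA = '{SOURCE_DB}' AND TABLE_NAME = '{SOURCE_TABLE}'
    AND COLUMN_KEY = 'PRI' AND DATA_TYPE IN ('int', 'bigint')
''')
if cursor.rowcount == 0:
    raise Exception("Integer primary key column not found")
pk_column = f"`{cursor.fetchone()[0]}`"

# Same max-PK lookup used to size the bulk import chunks
cursor.execute(f"SELECT MAX({pk_column}) FROM {SOURCE_DB}.{SOURCE_TABLE}")
max_pk = cursor.fetchone()[0]
if max_pk is None:
    raise Exception("No data in source table")
print(f"pk_column={pk_column}, max_pk={max_pk}")
conn.close()
```

A table without an integer primary key (or with no rows) fails this check, matching the errors the initializer now raises before chunk creation.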