diff --git a/data-prepper-plugins/rds-source/build.gradle b/data-prepper-plugins/rds-source/build.gradle index 0203cccf10..77f1022f63 100644 --- a/data-prepper-plugins/rds-source/build.gradle +++ b/data-prepper-plugins/rds-source/build.gradle @@ -24,6 +24,9 @@ dependencies { implementation 'com.zendesk:mysql-binlog-connector-java:0.29.2' implementation 'com.mysql:mysql-connector-j:8.4.0' + compileOnly 'org.projectlombok:lombok:1.18.20' + annotationProcessor 'org.projectlombok:lombok:1.18.20' + testImplementation project(path: ':data-prepper-test-common') testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' testImplementation project(path: ':data-prepper-test-event') diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactory.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactory.java index 419f1bf805..f6d0a26033 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactory.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactory.java @@ -12,6 +12,7 @@ import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.LeaderPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ResyncPartition; import java.util.function.Function; @@ -34,6 +35,8 @@ public EnhancedSourcePartition apply(SourcePartitionStoreItem partitionStoreItem return new DataFilePartition(partitionStoreItem); case StreamPartition.PARTITION_TYPE: return new StreamPartition(partitionStoreItem); + case ResyncPartition.PARTITION_TYPE: + return new ResyncPartition(partitionStoreItem); default: // Unable to acquire other partitions. 
return new GlobalState(partitionStoreItem);
diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/partition/ResyncPartition.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/partition/ResyncPartition.java
new file mode 100644
index 0000000000..5e4321119d
--- /dev/null
+++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/partition/ResyncPartition.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.dataprepper.plugins.source.rds.coordination.partition;
+
+import org.opensearch.dataprepper.model.source.coordinator.SourcePartitionStoreItem;
+import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition;
+import org.opensearch.dataprepper.plugins.source.rds.coordination.state.ResyncProgressState;
+
+import java.util.Optional;
+
+public class ResyncPartition extends EnhancedSourcePartition<ResyncProgressState> {
+
+    public static final String PARTITION_TYPE = "RESYNC";
+
+    private final String database;
+    private final String table;
+    private final long timestamp;
+    private final ResyncProgressState state;
+
+    public ResyncPartition(String database, String table, long timestamp, ResyncProgressState state) {
+        this.database = database;
+        this.table = table;
+        this.timestamp = timestamp;
+        this.state = state;
+    }
+
+    public ResyncPartition(SourcePartitionStoreItem sourcePartitionStoreItem) {
+        setSourcePartitionStoreItem(sourcePartitionStoreItem);
+        String[] keySplits = sourcePartitionStoreItem.getSourcePartitionKey().split("\\|");
+        database = keySplits[0];
+        table = keySplits[1];
+        timestamp = Long.parseLong(keySplits[2]);
+        state = convertStringToPartitionProgressState(ResyncProgressState.class, sourcePartitionStoreItem.getPartitionProgressState());
+    }
+
+    @Override
+    public String getPartitionType() {
+        return PARTITION_TYPE;
+    }
+
+    @Override
+    public String getPartitionKey() {
+        return database + "|" + table + "|" + timestamp;
+    }
+
+    @Override
+    public Optional<ResyncProgressState> getProgressState() {
+        if (state != null) {
+            return Optional.of(state);
+        }
+        return Optional.empty();
+    }
+}
diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/ResyncProgressState.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/ResyncProgressState.java
new file mode 100644
index 0000000000..6bdb18b6a8
--- /dev/null
+++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/ResyncProgressState.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.dataprepper.plugins.source.rds.coordination.state;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Getter;
+import lombok.Setter;
+
+import java.util.List;
+
+@Setter
+@Getter
+public class ResyncProgressState {
+    @JsonProperty("foreignKeyName")
+    private String foreignKeyName;
+
+    @JsonProperty("updatedValue")
+    private Object updatedValue;
+
+    @JsonProperty("primaryKeys")
+    private List<String> primaryKeys;
+}
diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java
b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java index 81a3c7f5ac..1f751e2087 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/coordination/state/StreamProgressState.java @@ -7,6 +7,9 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; + +import java.util.List; public class StreamProgressState { @@ -16,6 +19,9 @@ public class StreamProgressState { @JsonProperty("waitForExport") private boolean waitForExport = false; + @JsonProperty("foreignKeyRelations") + private List foreignKeyRelations; + public BinlogCoordinate getCurrentPosition() { return currentPosition; } @@ -31,4 +37,12 @@ public boolean shouldWaitForExport() { public void setWaitForExport(boolean waitForExport) { this.waitForExport = waitForExport; } + + public List getForeignKeyRelations() { + return foreignKeyRelations; + } + + public void setForeignKeyRelations(List foreignKeyRelations) { + this.foreignKeyRelations = foreignKeyRelations; + } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoader.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoader.java index 5e0fe9ecf3..df25b5c52c 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoader.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoader.java @@ -28,6 +28,7 @@ import java.util.concurrent.atomic.AtomicLong; import static org.opensearch.dataprepper.logging.DataPrepperMarkers.SENSITIVE; +import static org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata.DOT_DELIMITER; public class DataFileLoader implements Runnable { @@ -116,7 +117,7 @@ public void run() { DataFileProgressState progressState = dataFilePartition.getProgressState().get(); - final String fullTableName = progressState.getSourceDatabase() + "." 
+ progressState.getSourceTable(); + final String fullTableName = progressState.getSourceDatabase() + DOT_DELIMITER + progressState.getSourceTable(); final List primaryKeys = progressState.getPrimaryKeyMap().getOrDefault(fullTableName, List.of()); final long snapshotTime = progressState.getSnapshotTime(); diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java index ed80f136dc..fbf278ac8e 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/leader/LeaderScheduler.java @@ -160,6 +160,7 @@ private void createStreamPartition(RdsSourceConfig sourceConfig) { final StreamProgressState progressState = new StreamProgressState(); progressState.setWaitForExport(sourceConfig.isExportEnabled()); getCurrentBinlogPosition().ifPresent(progressState::setCurrentPosition); + progressState.setForeignKeyRelations(schemaManager.getForeignKeyRelations(sourceConfig.getTableNames())); StreamPartition streamPartition = new StreamPartition(sourceConfig.getDbIdentifier(), progressState); sourceCoordinator.createPartition(streamPartition); } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ForeignKeyAction.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ForeignKeyAction.java new file mode 100644 index 0000000000..eb986bbc87 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ForeignKeyAction.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.rds.model; + +import java.sql.DatabaseMetaData; +import java.util.Set; + +public enum ForeignKeyAction { + CASCADE, + NO_ACTION, + RESTRICT, + SET_DEFAULT, + SET_NULL, + UNKNOWN; + + private static final Set CASCADING_ACTIONS = Set.of(CASCADE, SET_DEFAULT, SET_NULL); + /** + * Returns the corresponding ForeignKeyAction for the given metadata action value. + * + * @param action the metadata action value + * @return the corresponding ForeignKeyAction + */ + public static ForeignKeyAction getActionFromMetadata(short action) { + switch (action) { + case DatabaseMetaData.importedKeyCascade: + return CASCADE; + case DatabaseMetaData.importedKeySetNull: + return SET_NULL; + case DatabaseMetaData.importedKeySetDefault: + return SET_DEFAULT; + case DatabaseMetaData.importedKeyRestrict: + return RESTRICT; + case DatabaseMetaData.importedKeyNoAction: + return NO_ACTION; + default: + return UNKNOWN; + } + } + + /** + * Checks if the foreign key action is one of the cascading actions (CASCADE, SET_DEFAULT, SET_NULL) + * that will result in changes to the foreign key value when referenced key in parent table changes. 
+ * + * @param foreignKeyAction the foreign key action + * @return true if the foreign key action is a cascade action, false otherwise + */ + public static boolean isCascadingAction(ForeignKeyAction foreignKeyAction) { + if (foreignKeyAction == null) { + return false; + } + return CASCADING_ACTIONS.contains(foreignKeyAction); + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ForeignKeyRelation.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ForeignKeyRelation.java new file mode 100644 index 0000000000..925fef1585 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ForeignKeyRelation.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.rds.model; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Builder; +import lombok.Getter; + +@Getter +@Builder +public class ForeignKeyRelation { + // TODO: add java docs + @JsonProperty("database_name") + private final String databaseName; + + @JsonProperty("parent_table_name") + private final String parentTableName; + + @JsonProperty("referenced_key_name") + private final String referencedKeyName; + + @JsonProperty("child_table_name") + private final String childTableName; + + @JsonProperty("foreign_key_name") + private final String foreignKeyName; + + @JsonProperty("foreign_key_default_value") + @Builder.Default + private Object foreignKeyDefaultValue = null; + + @JsonProperty("update_action") + private final ForeignKeyAction updateAction; + + @JsonProperty("delete_action") + private final ForeignKeyAction deleteAction; + + @JsonCreator + public ForeignKeyRelation(@JsonProperty("database_name") String databaseName, + @JsonProperty("parent_table_name") String parentTableName, + @JsonProperty("referenced_key_name") String referencedKeyName, + @JsonProperty("child_table_name") String childTableName, + @JsonProperty("foreign_key_name") String foreignKeyName, + @JsonProperty("foreign_key_default_value") Object foreignKeyDefaultValue, + @JsonProperty("update_action") ForeignKeyAction updateAction, + @JsonProperty("delete_action") ForeignKeyAction deleteAction) { + this.databaseName = databaseName; + this.parentTableName = parentTableName; + this.referencedKeyName = referencedKeyName; + this.childTableName = childTableName; + this.foreignKeyName = foreignKeyName; + this.foreignKeyDefaultValue = foreignKeyDefaultValue; + this.updateAction = updateAction; + this.deleteAction = deleteAction; + } + + /** + * Checks either update action or delete action is one of the cascading actions (CASCADE, SET_DEFAULT, SET_NULL). + * + * @param foreignKeyRelation The foreign key relation. + * @return True if the foreign key relation contains a cascade action, false otherwise. 
+ */ + public static boolean containsCascadingAction(ForeignKeyRelation foreignKeyRelation) { + return ForeignKeyAction.isCascadingAction(foreignKeyRelation.getUpdateAction()) || + ForeignKeyAction.isCascadingAction(foreignKeyRelation.getDeleteAction()); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ParentTable.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ParentTable.java new file mode 100644 index 0000000000..c25a039295 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/ParentTable.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.rds.model; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A data model for a parent table in a foreign key relationship + */ +@Getter +@Builder +public class ParentTable { + private final String databaseName; + private final String tableName; + /** + * Column name to a list of ForeignKeyRelation in which the column is referenced + */ + private final Map> referencedColumnMap; + + @Getter(AccessLevel.NONE) + @Builder.Default + private Map> columnsWithCascadingUpdate = null; + + @Getter(AccessLevel.NONE) + @Builder.Default + private Map> columnsWithCascadingDelete = null; + + /** + * Returns a map of column name to a list of ForeignKeyRelation in which the column is referenced and the update action is cascading. + * @return a map of column name to a list of ForeignKeyRelation + */ + public Map> getColumnsWithCascadingUpdate() { + if (columnsWithCascadingUpdate != null) { + return columnsWithCascadingUpdate; + } + + columnsWithCascadingUpdate = new HashMap<>(); + for (String column : referencedColumnMap.keySet()) { + for (ForeignKeyRelation foreignKeyRelation : referencedColumnMap.get(column)) { + if (ForeignKeyAction.isCascadingAction(foreignKeyRelation.getUpdateAction())) { + if (!columnsWithCascadingUpdate.containsKey(column)) { + columnsWithCascadingUpdate.put(column, new ArrayList<>()); + } + columnsWithCascadingUpdate.get(column).add(foreignKeyRelation); + } + } + } + return columnsWithCascadingUpdate; + } + + /** + * Returns a map of column name to a list of ForeignKeyRelation in which the column is referenced and the delete action is cascading. 
+ * @return a map of column name to a list of ForeignKeyRelation + */ + public Map> getColumnsWithCascadingDelete() { + if (columnsWithCascadingDelete != null) { + return columnsWithCascadingDelete; + } + + columnsWithCascadingDelete = new HashMap<>(); + for (String column : referencedColumnMap.keySet()) { + for (ForeignKeyRelation foreignKeyRelation : referencedColumnMap.get(column)) { + if (ForeignKeyAction.isCascadingAction(foreignKeyRelation.getDeleteAction())) { + if (!columnsWithCascadingDelete.containsKey(column)) { + columnsWithCascadingDelete.put(column, new ArrayList<>()); + } + columnsWithCascadingDelete.get(column).add(foreignKeyRelation); + } + } + } + return columnsWithCascadingDelete; + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/TableMetadata.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/TableMetadata.java index 310919c8ca..b3bf900eb8 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/TableMetadata.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/model/TableMetadata.java @@ -10,6 +10,8 @@ import java.util.Map; public class TableMetadata { + public static final String DOT_DELIMITER = "."; + private String databaseName; private String tableName; private List columnNames; @@ -40,7 +42,7 @@ public String getTableName() { } public String getFullTableName() { - return databaseName + "." + tableName; + return databaseName + DOT_DELIMITER + tableName; } public List getColumnNames() { diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java new file mode 100644 index 0000000000..98df985ccb --- /dev/null +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.rds.resync; + +import com.github.shyiko.mysql.binlog.event.Event; +import com.github.shyiko.mysql.binlog.event.UpdateRowsEventData; +import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ResyncPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.ResyncProgressState; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyAction; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; +import org.opensearch.dataprepper.plugins.source.rds.model.ParentTable; +import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata.DOT_DELIMITER; + +public class CascadingActionDetector { + + private static final Logger LOG = 
LoggerFactory.getLogger(CascadingActionDetector.class);
+
+    private final EnhancedSourceCoordinator sourceCoordinator;
+
+    public CascadingActionDetector(final EnhancedSourceCoordinator sourceCoordinator) {
+        this.sourceCoordinator = sourceCoordinator;
+    }
+
+    /**
+     * Gets TableName to ParentTable mapping from given stream partition.
+     * Only parent tables that have cascading update/delete actions defined (CASCADE, SET_NULL, SET_DEFAULT) are included in this map.
+     * @param streamPartition stream partition
+     * @return A map from TableName to ParentTable
+     */
+    public Map<String, ParentTable> getParentTableMap(StreamPartition streamPartition) {
+        final Map<String, ParentTable> parentTableMap = new HashMap<>();
+        if (streamPartition.getProgressState().isEmpty()) {
+            return parentTableMap;
+        }
+
+        List<ForeignKeyRelation> foreignKeyRelations = streamPartition.getProgressState().get().getForeignKeyRelations();
+
+        for (ForeignKeyRelation foreignKeyRelation : foreignKeyRelations) {
+            if (!ForeignKeyRelation.containsCascadingAction(foreignKeyRelation)) {
+                // skip foreign key relations without cascading actions
+                continue;
+            }
+
+            final String fullParentTableName = getFullTableName(foreignKeyRelation.getDatabaseName(), foreignKeyRelation.getParentTableName());
+            ParentTable parentTable;
+            if (!parentTableMap.containsKey(fullParentTableName)) {
+                Map<String, List<ForeignKeyRelation>> referencedColumnMap = new HashMap<>();
+                referencedColumnMap.put(foreignKeyRelation.getReferencedKeyName(), new ArrayList<>(List.of(foreignKeyRelation)));
+                parentTable = ParentTable.builder()
+                        .databaseName(foreignKeyRelation.getDatabaseName())
+                        .tableName(foreignKeyRelation.getParentTableName())
+                        .referencedColumnMap(referencedColumnMap)
+                        .build();
+                parentTableMap.put(fullParentTableName, parentTable);
+            } else {
+                parentTable = parentTableMap.get(fullParentTableName);
+                if (!parentTable.getReferencedColumnMap().containsKey(foreignKeyRelation.getReferencedKeyName())) {
+                    parentTable.getReferencedColumnMap().put(foreignKeyRelation.getReferencedKeyName(), new ArrayList<>());
+                }
+                parentTable.getReferencedColumnMap().get(foreignKeyRelation.getReferencedKeyName()).add(foreignKeyRelation);
+            }
+        }
+        LOG.debug("ParentTables are {}", parentTableMap.keySet());
+        return parentTableMap;
+    }
+
+    /**
+     * Detects if a binlog event contains cascading updates and, if detected, creates resync partitions
+     */
+    public void detectCascadingUpdates(Event event, Map<String, ParentTable> parentTableMap, TableMetadata tableMetadata) {
+        final UpdateRowsEventData data = event.getData();
+        if (parentTableMap.containsKey(tableMetadata.getFullTableName())) {
+            final ParentTable parentTable = parentTableMap.get(tableMetadata.getFullTableName());
+
+            for (Map.Entry<Serializable[], Serializable[]> row : data.getRows()) {
+                // Find out which columns are changing for this row
+                LOG.debug("Checking for updated columns");
+                final Map<String, Serializable> updatedColumnsAndValues = IntStream.range(0, row.getKey().length)
+                        .filter(i -> !row.getKey()[i].equals(row.getValue()[i]))
+                        .mapToObj(i -> tableMetadata.getColumnNames().get(i))
+                        .collect(Collectors.toMap(
+                                column -> column,
+                                column -> row.getValue()[tableMetadata.getColumnNames().indexOf(column)]
+                        ));
+                LOG.debug("These columns were updated: {}", updatedColumnsAndValues);
+
+                LOG.debug("Decide whether to create resync partitions");
+                // Create a resync partition if changed columns are associated with a cascading update
+                for (String column : updatedColumnsAndValues.keySet()) {
+                    if (parentTable.getColumnsWithCascadingUpdate().containsKey(column)) {
+                        for (ForeignKeyRelation foreignKeyRelation : parentTable.getColumnsWithCascadingUpdate().get(column)) {
+ if (foreignKeyRelation.getUpdateAction() == ForeignKeyAction.CASCADE) { + createResyncPartition( + foreignKeyRelation.getDatabaseName(), + foreignKeyRelation.getChildTableName(), + foreignKeyRelation.getForeignKeyName(), + updatedColumnsAndValues.get(column), + tableMetadata.getPrimaryKeys(), + event.getHeader().getTimestamp()); + } else if (foreignKeyRelation.getUpdateAction() == ForeignKeyAction.SET_NULL) { + createResyncPartition( + foreignKeyRelation.getDatabaseName(), + foreignKeyRelation.getChildTableName(), + foreignKeyRelation.getForeignKeyName(), + null, + tableMetadata.getPrimaryKeys(), + event.getHeader().getTimestamp()); + } else if (foreignKeyRelation.getUpdateAction() == ForeignKeyAction.SET_DEFAULT) { + createResyncPartition( + foreignKeyRelation.getDatabaseName(), + foreignKeyRelation.getChildTableName(), + foreignKeyRelation.getForeignKeyName(), + foreignKeyRelation.getForeignKeyDefaultValue(), + tableMetadata.getPrimaryKeys(), + event.getHeader().getTimestamp()); + } + } + } + } + } + } + } + + /** + * Detects if a binlog event contains cascading deletes and if detected, creates resync partitions + */ + public void detectCascadingDeletes(Event event, Map parentTableMap, TableMetadata tableMetadata) { + if (parentTableMap.containsKey(tableMetadata.getFullTableName())) { + final ParentTable parentTable = parentTableMap.get(tableMetadata.getFullTableName()); + + for (String column : parentTable.getColumnsWithCascadingDelete().keySet()) { + for (ForeignKeyRelation foreignKeyRelation : parentTable.getColumnsWithCascadingDelete().get(column)) { + if (foreignKeyRelation.getDeleteAction() == ForeignKeyAction.CASCADE) { + LOG.warn("Cascade delete is not supported yet"); + } else if (foreignKeyRelation.getDeleteAction() == ForeignKeyAction.SET_NULL) { + // foreign key in the child table will be set to NULL + createResyncPartition( + foreignKeyRelation.getDatabaseName(), + foreignKeyRelation.getChildTableName(), + foreignKeyRelation.getForeignKeyName(), + null, + tableMetadata.getPrimaryKeys(), + event.getHeader().getTimestamp()); + } else if (foreignKeyRelation.getDeleteAction() == ForeignKeyAction.SET_DEFAULT) { + createResyncPartition( + foreignKeyRelation.getDatabaseName(), + foreignKeyRelation.getChildTableName(), + foreignKeyRelation.getForeignKeyName(), + foreignKeyRelation.getForeignKeyDefaultValue(), + tableMetadata.getPrimaryKeys(), + event.getHeader().getTimestamp()); + } + } + } + } + } + + private String getFullTableName(String database, String table) { + return database + DOT_DELIMITER + table; + } + + private void createResyncPartition(String database, String childTable, String foreignKeyName, Object updatedValue, List primaryKeys, long eventTimestampMillis) { + LOG.debug("Create Resyc partition for table {} and column {} with new value {}", childTable, foreignKeyName, updatedValue); + final ResyncProgressState progressState = new ResyncProgressState(); + progressState.setForeignKeyName(foreignKeyName); + progressState.setUpdatedValue(updatedValue); + progressState.setPrimaryKeys(primaryKeys); + + final ResyncPartition resyncPartition = new ResyncPartition(database, childTable, eventTimestampMillis, progressState); + sourceCoordinator.createPartition(resyncPartition); + } +} diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java index 1ed848beee..bbe01ba160 100644 --- 
a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManager.java @@ -6,6 +6,8 @@ package org.opensearch.dataprepper.plugins.source.rds.schema; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyAction; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -21,6 +23,8 @@ public class SchemaManager { private static final Logger LOG = LoggerFactory.getLogger(SchemaManager.class); + + static final String[] TABLE_TYPES = new String[]{"TABLE"}; static final String COLUMN_NAME = "COLUMN_NAME"; static final String BINLOG_STATUS_QUERY = "SHOW MASTER STATUS"; static final String BINLOG_FILE = "File"; @@ -28,6 +32,13 @@ public class SchemaManager { static final int NUM_OF_RETRIES = 3; static final int BACKOFF_IN_MILLIS = 500; static final String TYPE_NAME = "TYPE_NAME"; + static final String FKTABLE_NAME = "FKTABLE_NAME"; + static final String FKCOLUMN_NAME = "FKCOLUMN_NAME"; + static final String PKTABLE_NAME = "PKTABLE_NAME"; + static final String PKCOLUMN_NAME = "PKCOLUMN_NAME"; + static final String UPDATE_RULE = "UPDATE_RULE"; + static final String DELETE_RULE = "DELETE_RULE"; + static final String COLUMN_DEF = "COLUMN_DEF"; private final ConnectionManager connectionManager; public SchemaManager(ConnectionManager connectionManager) { @@ -101,6 +112,70 @@ public Optional getCurrentBinaryLogPosition() { return Optional.empty(); } + /** + * Get the foreign key relations associated with the given tables. 
+ * + * @param tableNames the table names + * @return the foreign key relations + */ + public List getForeignKeyRelations(List tableNames) { + int retry = 0; + while (retry <= NUM_OF_RETRIES) { + try (final Connection connection = connectionManager.getConnection()) { + final List foreignKeyRelations = new ArrayList<>(); + DatabaseMetaData metaData = connection.getMetaData(); + for (final String tableName : tableNames) { + String database = tableName.split("\\.")[0]; + String table = tableName.split("\\.")[1]; + ResultSet tableResult = metaData.getTables(database, null, table, TABLE_TYPES); + while (tableResult.next()) { + ResultSet foreignKeys = metaData.getImportedKeys(database, null, table); + + while (foreignKeys.next()) { + String fkTableName = foreignKeys.getString(FKTABLE_NAME); + String fkColumnName = foreignKeys.getString(FKCOLUMN_NAME); + String pkTableName = foreignKeys.getString(PKTABLE_NAME); + String pkColumnName = foreignKeys.getString(PKCOLUMN_NAME); + ForeignKeyAction updateAction = ForeignKeyAction.getActionFromMetadata(foreignKeys.getShort(UPDATE_RULE)); + ForeignKeyAction deleteAction = ForeignKeyAction.getActionFromMetadata(foreignKeys.getShort(DELETE_RULE)); + + Object defaultValue = null; + if (updateAction == ForeignKeyAction.SET_DEFAULT || deleteAction == ForeignKeyAction.SET_DEFAULT) { + // Get column default + ResultSet columnResult = metaData.getColumns(database, null, table, fkColumnName); + + if (columnResult.next()) { + defaultValue = columnResult.getObject(COLUMN_DEF); + } + } + + ForeignKeyRelation foreignKeyRelation = ForeignKeyRelation.builder() + .databaseName(database) + .parentTableName(pkTableName) + .referencedKeyName(pkColumnName) + .childTableName(fkTableName) + .foreignKeyName(fkColumnName) + .foreignKeyDefaultValue(defaultValue) + .updateAction(updateAction) + .deleteAction(deleteAction) + .build(); + + foreignKeyRelations.add(foreignKeyRelation); + } + } + } + + return foreignKeyRelations; + } catch (Exception e) { + LOG.error("Failed to scan foreign key references, retrying", e); + } + applyBackoff(); + retry++; + } + LOG.warn("Failed to scan foreign key references"); + return List.of(); + } + private void applyBackoff() { try { Thread.sleep(BACKOFF_IN_MILLIS); diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java index 6bf3800337..34eecbca5e 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java @@ -29,8 +29,11 @@ import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; import org.opensearch.dataprepper.plugins.source.rds.converter.StreamRecordConverter; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.opensearch.dataprepper.plugins.source.rds.model.ParentTable; import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; +import org.opensearch.dataprepper.plugins.source.rds.resync.CascadingActionDetector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -64,6 +67,13 @@ public class BinlogEventListener 
implements BinaryLogClient.EventListener { */ private final Map tableMetadataMap; + /** + * TableName to ParentTable mapping. Only parent tables that have cascading update/delete actions defined + * (CASCADE, SET_NULL, SET_DEFAULT) are included in this map. + */ + private final Map parentTableMap; + + private final StreamPartition streamPartition; private final StreamRecordConverter recordConverter; private final BinaryLogClient binaryLogClient; private final Buffer> buffer; @@ -74,6 +84,8 @@ public class BinlogEventListener implements BinaryLogClient.EventListener { private final List pipelineEvents; private final StreamCheckpointManager streamCheckpointManager; private final ExecutorService binlogEventExecutorService; + private final CascadingActionDetector cascadeActionDetector; + private final Counter changeEventSuccessCounter; private final Counter changeEventErrorCounter; private final DistributionSummary bytesReceivedSummary; @@ -86,13 +98,16 @@ public class BinlogEventListener implements BinaryLogClient.EventListener { */ private BinlogCoordinate currentBinlogCoordinate; - public BinlogEventListener(final Buffer> buffer, + public BinlogEventListener(final StreamPartition streamPartition, + final Buffer> buffer, final RdsSourceConfig sourceConfig, final String s3Prefix, final PluginMetrics pluginMetrics, final BinaryLogClient binaryLogClient, final StreamCheckpointer streamCheckpointer, - final AcknowledgementSetManager acknowledgementSetManager) { + final AcknowledgementSetManager acknowledgementSetManager, + final CascadingActionDetector cascadeActionDetector) { + this.streamPartition = streamPartition; this.buffer = buffer; this.binaryLogClient = binaryLogClient; tableMetadataMap = new HashMap<>(); @@ -110,6 +125,9 @@ public BinlogEventListener(final Buffer> buffer, acknowledgementSetManager, this::stopClient, sourceConfig.getStreamAcknowledgmentTimeout()); streamCheckpointManager.start(); + this.cascadeActionDetector = cascadeActionDetector; + parentTableMap = cascadeActionDetector.getParentTableMap(streamPartition); + changeEventSuccessCounter = pluginMetrics.counter(CHANGE_EVENTS_PROCESSED_COUNT); changeEventErrorCounter = pluginMetrics.counter(CHANGE_EVENTS_PROCESSING_ERROR_COUNT); bytesReceivedSummary = pluginMetrics.summary(BYTES_RECEIVED); @@ -117,14 +135,16 @@ public BinlogEventListener(final Buffer> buffer, eventProcessingTimer = pluginMetrics.timer(REPLICATION_LOG_EVENT_PROCESSING_TIME); } - public static BinlogEventListener create(final Buffer> buffer, + public static BinlogEventListener create(final StreamPartition streamPartition, + final Buffer> buffer, final RdsSourceConfig sourceConfig, final String s3Prefix, final PluginMetrics pluginMetrics, final BinaryLogClient binaryLogClient, final StreamCheckpointer streamCheckpointer, - final AcknowledgementSetManager acknowledgementSetManager) { - return new BinlogEventListener(buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, streamCheckpointer, acknowledgementSetManager); + final AcknowledgementSetManager acknowledgementSetManager, + final CascadingActionDetector cascadeActionDetector) { + return new BinlogEventListener(streamPartition, buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, streamCheckpointer, acknowledgementSetManager, cascadeActionDetector); } @Override @@ -194,6 +214,11 @@ void handleTableMapEvent(com.github.shyiko.mysql.binlog.event.Event event) { void handleInsertEvent(com.github.shyiko.mysql.binlog.event.Event event) { LOG.debug("Handling insert event"); final 
WriteRowsEventData data = event.getData(); + + if (!isValidTableId(data.getTableId())) { + return; + } + handleRowChangeEvent(event, data.getTableId(), data.getRows(), OpenSearchBulkActions.INDEX); } @@ -201,6 +226,13 @@ void handleUpdateEvent(com.github.shyiko.mysql.binlog.event.Event event) { LOG.debug("Handling update event"); final UpdateRowsEventData data = event.getData(); + if (!isValidTableId(data.getTableId())) { + return; + } + + // Check if a cascade action is involved + cascadeActionDetector.detectCascadingUpdates(event, parentTableMap, tableMetadataMap.get(data.getTableId())); + // updatedRow contains data before update as key and data after update as value final List rows = data.getRows().stream() .map(Map.Entry::getValue) @@ -213,9 +245,30 @@ void handleDeleteEvent(com.github.shyiko.mysql.binlog.event.Event event) { LOG.debug("Handling delete event"); final DeleteRowsEventData data = event.getData(); + if (!isValidTableId(data.getTableId())) { + return; + } + + // Check if a cascade action is involved + cascadeActionDetector.detectCascadingDeletes(event, parentTableMap, tableMetadataMap.get(data.getTableId())); + handleRowChangeEvent(event, data.getTableId(), data.getRows(), OpenSearchBulkActions.DELETE); } + private boolean isValidTableId(long tableId) { + if (!tableMetadataMap.containsKey(tableId)) { + LOG.debug("Cannot find table metadata, the event is likely not from a table of interest or the table metadata was not read"); + return false; + } + + if (!isTableOfInterest(tableMetadataMap.get(tableId).getFullTableName())) { + LOG.debug("The event is not from a table of interest"); + return false; + } + + return true; + } + private void handleRowChangeEvent(com.github.shyiko.mysql.binlog.event.Event event, long tableId, List rows, @@ -236,16 +289,7 @@ private void handleRowChangeEvent(com.github.shyiko.mysql.binlog.event.Event eve final long bytes = event.toString().getBytes().length; bytesReceivedSummary.record(bytes); - if (!tableMetadataMap.containsKey(tableId)) { - LOG.debug("Cannot find table metadata, the event is likely not from a table of interest or the table metadata was not read"); - return; - } final TableMetadata tableMetadata = tableMetadataMap.get(tableId); - final String fullTableName = tableMetadata.getFullTableName(); - if (!isTableOfInterest(fullTableName)) { - LOG.debug("The event is not from a table of interest"); - return; - } final List columnNames = tableMetadata.getColumnNames(); final List primaryKeys = tableMetadata.getPrimaryKeys(); final long eventTimestampMillis = event.getHeader().getTimestamp(); diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java index 270a4f3fbf..16c8fc94b2 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java @@ -16,6 +16,7 @@ import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.resync.CascadingActionDetector; import 
org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -113,8 +114,10 @@ public void shutdown() { private void refreshTask(RdsSourceConfig sourceConfig) { final BinaryLogClient binaryLogClient = binlogClientFactory.create(); + final CascadingActionDetector cascadeActionDetector = new CascadingActionDetector(sourceCoordinator); binaryLogClient.registerEventListener(BinlogEventListener.create( - buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, streamCheckpointer, acknowledgementSetManager)); + streamPartition, buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, + streamCheckpointer, acknowledgementSetManager, cascadeActionDetector)); final StreamWorker streamWorker = StreamWorker.create(sourceCoordinator, binaryLogClient, pluginMetrics); executorService.submit(() -> streamWorker.processStream(streamPartition)); } diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactoryTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactoryTest.java index c092a8b48c..4f55c44b98 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactoryTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/coordination/PartitionFactoryTest.java @@ -10,9 +10,12 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.model.source.coordinator.SourcePartitionStoreItem; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.DataFilePartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ExportPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.LeaderPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ResyncPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import java.util.UUID; @@ -45,6 +48,36 @@ void given_export_partition_item_then_create_export_partition() { assertThat(objectUnderTest.apply(partitionStoreItem), instanceOf(ExportPartition.class)); } + @Test + void given_stream_partition_item_then_create_stream_partition() { + PartitionFactory objectUnderTest = createObjectUnderTest(); + when(partitionStoreItem.getSourceIdentifier()).thenReturn(UUID.randomUUID() + "|" + StreamPartition.PARTITION_TYPE); + when(partitionStoreItem.getSourcePartitionKey()).thenReturn(UUID.randomUUID().toString()); + when(partitionStoreItem.getPartitionProgressState()).thenReturn(null); + + assertThat(objectUnderTest.apply(partitionStoreItem), instanceOf(StreamPartition.class)); + } + + @Test + void given_datafile_partition_item_then_create_datafile_partition() { + PartitionFactory objectUnderTest = createObjectUnderTest(); + when(partitionStoreItem.getSourceIdentifier()).thenReturn(UUID.randomUUID() + "|" + DataFilePartition.PARTITION_TYPE); + when(partitionStoreItem.getSourcePartitionKey()).thenReturn(UUID.randomUUID() + "|" + UUID.randomUUID() + "|" + UUID.randomUUID()); + when(partitionStoreItem.getPartitionProgressState()).thenReturn(null); + + assertThat(objectUnderTest.apply(partitionStoreItem), instanceOf(DataFilePartition.class)); + } + + @Test + void given_resync_partition_item_then_create_resync_partition() { + 
PartitionFactory objectUnderTest = createObjectUnderTest(); + when(partitionStoreItem.getSourceIdentifier()).thenReturn(UUID.randomUUID() + "|" + ResyncPartition.PARTITION_TYPE); + when(partitionStoreItem.getSourcePartitionKey()).thenReturn(UUID.randomUUID() + "|" + UUID.randomUUID() + "|12345"); + when(partitionStoreItem.getPartitionProgressState()).thenReturn(null); + + assertThat(objectUnderTest.apply(partitionStoreItem), instanceOf(ResyncPartition.class)); + } + @Test void given_store_item_of_undefined_type_then_create_global_state() { PartitionFactory objectUnderTest = createObjectUnderTest(); diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetectorTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetectorTest.java new file mode 100644 index 0000000000..c17e003782 --- /dev/null +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetectorTest.java @@ -0,0 +1,185 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.rds.resync; + +import com.github.shyiko.mysql.binlog.event.Event; +import com.github.shyiko.mysql.binlog.event.UpdateRowsEventData; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ResyncPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.ResyncProgressState; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyAction; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; +import org.opensearch.dataprepper.plugins.source.rds.model.ParentTable; +import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; + +import java.io.Serializable; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata.DOT_DELIMITER; + +@ExtendWith(MockitoExtension.class) +class CascadingActionDetectorTest { + @Mock + private EnhancedSourceCoordinator sourceCoordinator; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Event event; + + @Mock + private TableMetadata tableMetadata; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private StreamPartition streamPartition; + + private CascadingActionDetector objectUnderTest; + private ForeignKeyRelation foreignKeyRelationWithCascading; + private ForeignKeyRelation foreignKeyRelationWithoutCascading; + private Map parentTableMap; + + @BeforeEach + void setUp() { + objectUnderTest = 
createObjectUnderTest(); + prepareTestTables(); + } + + @Test + void testGetParentTableMap_returns_empty_list_if_stream_progress_state_is_empty() { + when(streamPartition.getProgressState()).thenReturn(Optional.empty()); + + Map actualParentTableMap = objectUnderTest.getParentTableMap(streamPartition); + + assertThat(actualParentTableMap.size(), is(0)); + } + + @Test + void testGetParentTableMap_returns_only_foreign_relations_with_cascading_actions() { + final StreamProgressState progressState = mock(StreamProgressState.class); + when(streamPartition.getProgressState()).thenReturn(Optional.of(progressState)); + when(progressState.getForeignKeyRelations()).thenReturn(List.of(foreignKeyRelationWithCascading, foreignKeyRelationWithoutCascading)); + + Map actualParentTableMap = objectUnderTest.getParentTableMap(streamPartition); + + assertThat(actualParentTableMap.size(), is(1)); + assertThat(actualParentTableMap.containsKey("test-database.parent-table1"), is(true)); + + final ParentTable parentTable = actualParentTableMap.get("test-database.parent-table1"); + assertThat(parentTable.getDatabaseName(), is("test-database")); + assertThat(parentTable.getTableName(), is("parent-table1")); + assertThat(parentTable.getReferencedColumnMap().size(), is(1)); + assertThat(parentTable.getReferencedColumnMap().containsKey("referenced-column"), is(true)); + assertThat(parentTable.getReferencedColumnMap().get("referenced-column").size(), is(1)); + assertThat(parentTable.getReferencedColumnMap().get("referenced-column").get(0), is(foreignKeyRelationWithCascading)); + } + + @Test + void testDetectCascadingUpdates() { + UpdateRowsEventData data = mock(UpdateRowsEventData.class); + List> rows = List.of(Map.entry(new Serializable[]{"old-value"}, new Serializable[]{"new-value"})); + long timestampInMillis = Instant.now().toEpochMilli(); + List primaryKeys = List.of("primary-key"); + when(event.getData()).thenReturn(data); + when(event.getHeader().getTimestamp()).thenReturn(timestampInMillis); + when(tableMetadata.getFullTableName()).thenReturn("test-database.parent-table1"); + when(data.getRows()).thenReturn(rows); + when(tableMetadata.getColumnNames()).thenReturn(List.of("referenced-column")); + when(tableMetadata.getPrimaryKeys()).thenReturn(primaryKeys); + + objectUnderTest.detectCascadingUpdates(event, parentTableMap, tableMetadata); + + ArgumentCaptor resyncPartitionArgumentCaptor = ArgumentCaptor.forClass(ResyncPartition.class); + verify(sourceCoordinator).createPartition(resyncPartitionArgumentCaptor.capture()); + ResyncPartition resyncPartition = resyncPartitionArgumentCaptor.getValue(); + + assertThat(resyncPartition.getPartitionKey(), is("test-database|child-table|" + timestampInMillis)); + + ResyncProgressState progressState = resyncPartition.getProgressState().get(); + assertThat(progressState.getForeignKeyName(), is("foreign-key1")); + assertThat(progressState.getUpdatedValue(), is("new-value")); + assertThat(progressState.getPrimaryKeys(), is(primaryKeys)); + } + + @Test + void detectCascadingDeletes() { + long timestampInMillis = Instant.now().toEpochMilli(); + List primaryKeys = List.of("primary-key"); + when(event.getHeader().getTimestamp()).thenReturn(timestampInMillis); + when(tableMetadata.getFullTableName()).thenReturn("test-database.parent-table1"); + when(tableMetadata.getPrimaryKeys()).thenReturn(primaryKeys); + + objectUnderTest.detectCascadingDeletes(event, parentTableMap, tableMetadata); + + ArgumentCaptor resyncPartitionArgumentCaptor = ArgumentCaptor.forClass(ResyncPartition.class); + 
verify(sourceCoordinator).createPartition(resyncPartitionArgumentCaptor.capture()); + ResyncPartition resyncPartition = resyncPartitionArgumentCaptor.getValue(); + + assertThat(resyncPartition.getPartitionKey(), is("test-database|child-table|" + timestampInMillis)); + + ResyncProgressState progressState = resyncPartition.getProgressState().get(); + assertThat(progressState.getForeignKeyName(), is("foreign-key1")); + assertThat(progressState.getUpdatedValue(), nullValue()); + assertThat(progressState.getPrimaryKeys(), is(primaryKeys)); + } + + private CascadingActionDetector createObjectUnderTest() { + return new CascadingActionDetector(sourceCoordinator); + } + + private void prepareTestTables() { + final String databaseName = "test-database"; + final String parentTableName1 = "parent-table1"; + final String parentTableName2 = "parent-table2"; + final String referencedColumnName = "referenced-column"; + final String childTableName = "child-table"; + final String foreignKey1 = "foreign-key1"; + final String foreignKey2 = "foreign-key2"; + + + foreignKeyRelationWithCascading = ForeignKeyRelation.builder() + .databaseName(databaseName) + .parentTableName(parentTableName1) + .referencedKeyName(referencedColumnName) + .childTableName(childTableName) + .foreignKeyName(foreignKey1) + .updateAction(ForeignKeyAction.CASCADE) + .deleteAction(ForeignKeyAction.SET_NULL) + .build(); + + foreignKeyRelationWithoutCascading = ForeignKeyRelation.builder() + .databaseName(databaseName) + .parentTableName(parentTableName2) + .referencedKeyName(referencedColumnName) + .childTableName(childTableName) + .foreignKeyName(foreignKey2) + .updateAction(ForeignKeyAction.RESTRICT) + .build(); + + ParentTable parentTable1 = ParentTable.builder() + .databaseName(databaseName) + .tableName(parentTableName1) + .referencedColumnMap(Map.of(referencedColumnName, List.of(foreignKeyRelationWithCascading))) + .build(); + + parentTableMap = Map.of(databaseName + DOT_DELIMITER + parentTableName1, parentTable1); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerTest.java index 98d3874703..ce6af88009 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/schema/SchemaManagerTest.java @@ -12,6 +12,8 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyAction; +import org.opensearch.dataprepper.plugins.source.rds.model.ForeignKeyRelation; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -30,13 +32,22 @@ import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata.DOT_DELIMITER; import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_FILE; import static 
org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_POSITION; import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_STATUS_QUERY; import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.COLUMN_NAME; import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.TYPE_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.DELETE_RULE; +import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.FKCOLUMN_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.FKTABLE_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.PKCOLUMN_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.PKTABLE_NAME; +import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.UPDATE_RULE; @ExtendWith(MockitoExtension.class) class SchemaManagerTest { @@ -167,6 +178,45 @@ void getColumnDataTypes_whenColumnsExist_shouldReturnValidMapping() throws SQLEx assertThat(result, equalTo(expectedColumnTypes)); } + @Test + void test_getForeignKeyRelations_returns_foreign_key_relations() throws SQLException { + final String databaseName = "test-db"; + final String tableName = "test-table"; + final List tableNames = List.of(databaseName + DOT_DELIMITER + tableName); + final ResultSet tableResult = mock(ResultSet.class); + final ResultSet foreignKeys = mock(ResultSet.class); + final String fkTableName = UUID.randomUUID().toString(); + final String fkColumnName = UUID.randomUUID().toString(); + final String pkTableName = UUID.randomUUID().toString(); + final String pkColumnName = UUID.randomUUID().toString(); + final DatabaseMetaData metaData = mock(DatabaseMetaData.class); + + when(connectionManager.getConnection()).thenReturn(connection); + when(connection.getMetaData()).thenReturn(metaData); + when(metaData.getTables(eq(databaseName), any(), eq(tableName), any())).thenReturn(tableResult); + when(tableResult.next()).thenReturn(true, false); + when(metaData.getImportedKeys(eq(databaseName), any(), eq(tableName))).thenReturn(foreignKeys); + when(foreignKeys.next()).thenReturn(true, false); + when(foreignKeys.getString(FKTABLE_NAME)).thenReturn(fkTableName); + when(foreignKeys.getString(FKCOLUMN_NAME)).thenReturn(fkColumnName); + when(foreignKeys.getString(PKTABLE_NAME)).thenReturn(pkTableName); + when(foreignKeys.getString(PKCOLUMN_NAME)).thenReturn(pkColumnName); + when(foreignKeys.getShort(UPDATE_RULE)).thenReturn((short)DatabaseMetaData.importedKeyCascade); + when(foreignKeys.getShort(DELETE_RULE)).thenReturn((short)DatabaseMetaData.importedKeySetNull); + + final List foreignKeyRelations = schemaManager.getForeignKeyRelations(tableNames); + + assertThat(foreignKeyRelations.size(), is(1)); + + ForeignKeyRelation foreignKeyRelation = foreignKeyRelations.get(0); + assertThat(foreignKeyRelation.getParentTableName(), is(pkTableName)); + assertThat(foreignKeyRelation.getReferencedKeyName(), is(pkColumnName)); + assertThat(foreignKeyRelation.getChildTableName(), is(fkTableName)); + assertThat(foreignKeyRelation.getForeignKeyName(), is(fkColumnName)); + assertThat(foreignKeyRelation.getUpdateAction(), is(ForeignKeyAction.CASCADE)); + assertThat(foreignKeyRelation.getDeleteAction(), is(ForeignKeyAction.SET_NULL)); + } + private SchemaManager createObjectUnderTest() { return new SchemaManager(connectionManager); } diff --git 
a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java index 1312607821..a6f3939acb 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListenerTest.java @@ -26,6 +26,8 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; +import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.resync.CascadingActionDetector; import java.io.IOException; import java.util.UUID; @@ -46,6 +48,9 @@ @ExtendWith(MockitoExtension.class) class BinlogEventListenerTest { + @Mock + private StreamPartition streamPartition; + @Mock private Buffer> buffer; @@ -64,6 +69,9 @@ class BinlogEventListenerTest { @Mock private AcknowledgementSetManager acknowledgementSetManager; + @Mock + private CascadingActionDetector cascadingActionDetector; + @Mock private ExecutorService eventListnerExecutorService; @@ -153,7 +161,8 @@ void test_given_DeleteRows_event_then_calls_correct_handler(EventType eventType) } private BinlogEventListener createObjectUnderTest() { - return new BinlogEventListener(buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, streamCheckpointer, acknowledgementSetManager); + return BinlogEventListener.create(streamPartition, buffer, sourceConfig, s3Prefix, pluginMetrics, binaryLogClient, + streamCheckpointer, acknowledgementSetManager, cascadingActionDetector); } private void verifyHandlerCallHelper() { diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java index 13078e65cb..eab2369648 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresherTest.java @@ -23,6 +23,7 @@ import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; +import org.opensearch.dataprepper.plugins.source.rds.resync.CascadingActionDetector; import java.util.UUID; import java.util.concurrent.ExecutorService; @@ -108,7 +109,9 @@ void test_initialize_then_process_stream() { MockedStatic binlogEventListenerMockedStatic = mockStatic(BinlogEventListener.class)) { streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(BinaryLogClient.class), eq(pluginMetrics))) .thenReturn(streamWorker); - binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(buffer), any(RdsSourceConfig.class), any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), eq(acknowledgementSetManager))) + binlogEventListenerMockedStatic.when(() -> 
BinlogEventListener.create(eq(streamPartition), eq(buffer), any(RdsSourceConfig.class), + any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), + eq(acknowledgementSetManager), any(CascadingActionDetector.class))) .thenReturn(binlogEventListener); streamWorkerTaskRefresher.initialize(sourceConfig); } @@ -142,7 +145,9 @@ void test_update_when_credentials_changed_then_refresh_task() { MockedStatic binlogEventListenerMockedStatic = mockStatic(BinlogEventListener.class)) { streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(BinaryLogClient.class), eq(pluginMetrics))) .thenReturn(streamWorker); - binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(buffer), any(RdsSourceConfig.class), any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), eq(acknowledgementSetManager))) + binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(streamPartition), eq(buffer), any(RdsSourceConfig.class), + any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), + eq(acknowledgementSetManager), any(CascadingActionDetector.class))) .thenReturn(binlogEventListener); streamWorkerTaskRefresher.initialize(sourceConfig); streamWorkerTaskRefresher.update(sourceConfig2); @@ -175,7 +180,9 @@ void test_update_when_credentials_unchanged_then_do_nothing() { MockedStatic binlogEventListenerMockedStatic = mockStatic(BinlogEventListener.class)) { streamWorkerMockedStatic.when(() -> StreamWorker.create(eq(sourceCoordinator), any(BinaryLogClient.class), eq(pluginMetrics))) .thenReturn(streamWorker); - binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(buffer), any(RdsSourceConfig.class), any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), eq(acknowledgementSetManager))) + binlogEventListenerMockedStatic.when(() -> BinlogEventListener.create(eq(streamPartition), eq(buffer), any(RdsSourceConfig.class), + any(String.class), eq(pluginMetrics), eq(binlogClient), eq(streamCheckpointer), + eq(acknowledgementSetManager), any(CascadingActionDetector.class))) .thenReturn(binlogEventListener); streamWorkerTaskRefresher.initialize(sourceConfig); streamWorkerTaskRefresher.update(sourceConfig);
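
Note (not part of the diff): the change introduces a "database|table|timestamp" partition key for ResyncPartition, which PartitionFactoryTest and CascadingActionDetectorTest assert against. The standalone sketch below mirrors that round-trip; the class and field names are illustrative only and do not exist in the repository.

import java.util.Objects;

// Illustrative only: mirrors the "database|table|timestamp" key convention used by ResyncPartition.
final class ResyncKeySketch {
    private final String database;
    private final String table;
    private final long timestamp;

    ResyncKeySketch(String database, String table, long timestamp) {
        this.database = database;
        this.table = table;
        this.timestamp = timestamp;
    }

    // Compose the key the same way ResyncPartition.getPartitionKey() does.
    String toKey() {
        return database + "|" + table + "|" + timestamp;
    }

    // Parse a stored key back, as the SourcePartitionStoreItem constructor does with split("\\|").
    static ResyncKeySketch fromKey(String key) {
        String[] parts = Objects.requireNonNull(key).split("\\|");
        return new ResyncKeySketch(parts[0], parts[1], Long.parseLong(parts[2]));
    }

    public static void main(String[] args) {
        ResyncKeySketch original = new ResyncKeySketch("test-database", "child-table", 1719866400000L);
        String key = original.toKey();
        ResyncKeySketch restored = ResyncKeySketch.fromKey(key);
        System.out.println(key + " -> " + restored.toKey()); // round-trips unchanged
    }
}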
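
Also for orientation (again not part of the diff): SchemaManager.getForeignKeyRelations reads UPDATE_RULE and DELETE_RULE from DatabaseMetaData.getImportedKeys and maps them through ForeignKeyAction.getActionFromMetadata. The mapping relies only on the standard java.sql.DatabaseMetaData rule constants, roughly as sketched here; the enum values match the diff, while the class and method names are illustrative assumptions.

import java.sql.DatabaseMetaData;

// Illustrative only: shows which java.sql.DatabaseMetaData rule constants map to which action,
// in the same way ForeignKeyAction.getActionFromMetadata does in the diff.
final class ForeignKeyRuleSketch {
    enum Action { CASCADE, SET_NULL, SET_DEFAULT, RESTRICT, NO_ACTION, UNKNOWN }

    static Action fromRule(int rule) {
        switch (rule) {
            case DatabaseMetaData.importedKeyCascade:    return Action.CASCADE;
            case DatabaseMetaData.importedKeySetNull:    return Action.SET_NULL;
            case DatabaseMetaData.importedKeySetDefault: return Action.SET_DEFAULT;
            case DatabaseMetaData.importedKeyRestrict:   return Action.RESTRICT;
            case DatabaseMetaData.importedKeyNoAction:   return Action.NO_ACTION;
            default:                                     return Action.UNKNOWN;
        }
    }

    public static void main(String[] args) {
        // CASCADE, SET_NULL and SET_DEFAULT change child rows implicitly on the database side,
        // which is why the diff treats them as cascading actions that require a resync partition.
        System.out.println(fromRule(DatabaseMetaData.importedKeyCascade));  // CASCADE
        System.out.println(fromRule(DatabaseMetaData.importedKeyNoAction)); // NO_ACTION
    }
}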