Several improvements to RDS source (#4810)
* Add schema manager to query database

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Get real primary keys for export

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Get binlog start position for stream

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Refactor SnapshotStrategy to RdsApiStrategy

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Update unit tests

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Address comments

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Add retry to database queries

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Handle describe exceptions

Signed-off-by: Hai Yan <oeyh@amazon.com>

* Address more comments

Signed-off-by: Hai Yan <oeyh@amazon.com>

---------

Signed-off-by: Hai Yan <oeyh@amazon.com>
oeyh authored Aug 8, 2024
1 parent 642db0d commit 04de9eb
Showing 25 changed files with 705 additions and 137 deletions.
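
The central refactor in this commit renames the export-only SnapshotStrategy to RdsApiStrategy, moves it to the leader package, and adds a describeDb operation alongside snapshot creation and description. The interface itself is not among the excerpts below; judging from the methods ClusterApiStrategy overrides, it plausibly has the following shape (a sketch, not the verbatim source):

```java
import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata;
import org.opensearch.dataprepper.plugins.source.rds.model.SnapshotInfo;

/**
 * Sketch of the RdsApiStrategy interface, inferred from the methods that
 * ClusterApiStrategy overrides in this commit; the real file may differ.
 */
public interface RdsApiStrategy {
    // Describe the DB instance or cluster and return its endpoint metadata.
    DbMetadata describeDb(String dbIdentifier);

    // Create a snapshot of the DB and return its metadata.
    SnapshotInfo createSnapshot(String dbIdentifier, String snapshotId);

    // Describe an existing snapshot; per the diff below, returns null on failure.
    SnapshotInfo describeSnapshot(String snapshotId);
}
```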
1 change: 1 addition & 0 deletions data-prepper-plugins/rds-source/build.gradle
@@ -22,6 +22,7 @@ dependencies {
implementation 'com.fasterxml.jackson.core:jackson-databind'

implementation 'com.zendesk:mysql-binlog-connector-java:0.29.2'
implementation 'com.mysql:mysql-connector-j:8.4.0'

testImplementation project(path: ':data-prepper-test-common')
testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml'
RdsService.java
@@ -13,14 +13,17 @@
import org.opensearch.dataprepper.model.event.EventFactory;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
import org.opensearch.dataprepper.plugins.source.rds.export.ClusterSnapshotStrategy;
import org.opensearch.dataprepper.plugins.source.rds.export.DataFileScheduler;
import org.opensearch.dataprepper.plugins.source.rds.export.ExportScheduler;
import org.opensearch.dataprepper.plugins.source.rds.export.ExportTaskManager;
import org.opensearch.dataprepper.plugins.source.rds.export.InstanceSnapshotStrategy;
import org.opensearch.dataprepper.plugins.source.rds.export.SnapshotManager;
import org.opensearch.dataprepper.plugins.source.rds.export.SnapshotStrategy;
import org.opensearch.dataprepper.plugins.source.rds.leader.ClusterApiStrategy;
import org.opensearch.dataprepper.plugins.source.rds.leader.InstanceApiStrategy;
import org.opensearch.dataprepper.plugins.source.rds.leader.LeaderScheduler;
import org.opensearch.dataprepper.plugins.source.rds.leader.RdsApiStrategy;
import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata;
import org.opensearch.dataprepper.plugins.source.rds.schema.ConnectionManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager;
import org.opensearch.dataprepper.plugins.source.rds.stream.BinlogClientFactory;
import org.opensearch.dataprepper.plugins.source.rds.stream.StreamScheduler;
import org.slf4j.Logger;
@@ -77,13 +80,16 @@ public RdsService(final EnhancedSourceCoordinator sourceCoordinator,
public void start(Buffer<Record<Event>> buffer) {
LOG.info("Start running RDS service");
final List<Runnable> runnableList = new ArrayList<>();
leaderScheduler = new LeaderScheduler(sourceCoordinator, sourceConfig);

final RdsApiStrategy rdsApiStrategy = sourceConfig.isCluster() ?
new ClusterApiStrategy(rdsClient) : new InstanceApiStrategy(rdsClient);
final DbMetadata dbMetadata = rdsApiStrategy.describeDb(sourceConfig.getDbIdentifier());
leaderScheduler = new LeaderScheduler(
sourceCoordinator, sourceConfig, getSchemaManager(sourceConfig, dbMetadata), dbMetadata);
runnableList.add(leaderScheduler);

if (sourceConfig.isExportEnabled()) {
final SnapshotStrategy snapshotStrategy = sourceConfig.isCluster() ?
new ClusterSnapshotStrategy(rdsClient) : new InstanceSnapshotStrategy(rdsClient);
final SnapshotManager snapshotManager = new SnapshotManager(snapshotStrategy);
final SnapshotManager snapshotManager = new SnapshotManager(rdsApiStrategy);
final ExportTaskManager exportTaskManager = new ExportTaskManager(rdsClient);
exportScheduler = new ExportScheduler(
sourceCoordinator, snapshotManager, exportTaskManager, s3Client, pluginMetrics);
@@ -94,7 +100,7 @@ public void start(Buffer<Record<Event>> buffer) {
}

if (sourceConfig.isStreamEnabled()) {
BinaryLogClient binaryLogClient = new BinlogClientFactory(sourceConfig, rdsClient).create();
BinaryLogClient binaryLogClient = new BinlogClientFactory(sourceConfig, rdsClient, dbMetadata).create();
if (sourceConfig.getTlsConfig() == null || !sourceConfig.getTlsConfig().isInsecure()) {
binaryLogClient.setSSLMode(SSLMode.REQUIRED);
} else {
@@ -128,4 +134,14 @@ public void shutdown() {
executor.shutdownNow();
}
}

private SchemaManager getSchemaManager(final RdsSourceConfig sourceConfig, final DbMetadata dbMetadata) {
final ConnectionManager connectionManager = new ConnectionManager(
dbMetadata.getHostName(),
dbMetadata.getPort(),
sourceConfig.getAuthenticationConfig().getUsername(),
sourceConfig.getAuthenticationConfig().getPassword(),
!sourceConfig.getTlsConfig().isInsecure());
return new SchemaManager(connectionManager);
}
}
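
The ConnectionManager and SchemaManager wired up in getSchemaManager above are new in this commit, but their sources are not part of this excerpt. Below is a minimal sketch of what they could look like, assuming a plain JDBC connection through the newly added mysql-connector-j dependency and the standard DatabaseMetaData primary-key lookup; class shapes and method names are illustrative, not the commit's actual code.

```java
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

// Sketch only: opens JDBC connections to the MySQL endpoint from DbMetadata.
class ConnectionManager {
    private final String hostName;
    private final int port;
    private final String username;
    private final String password;
    private final boolean requireSSL;

    ConnectionManager(String hostName, int port, String username, String password, boolean requireSSL) {
        this.hostName = hostName;
        this.port = port;
        this.username = username;
        this.password = password;
        this.requireSSL = requireSSL;
    }

    Connection getConnection() throws SQLException {
        final Properties props = new Properties();
        props.setProperty("user", username);
        props.setProperty("password", password);
        if (requireSSL) {
            props.setProperty("useSSL", "true");
            props.setProperty("requireSSL", "true");
        }
        return DriverManager.getConnection("jdbc:mysql://" + hostName + ":" + port, props);
    }
}

// Sketch only: queries schema information, e.g. primary keys per table.
class SchemaManager {
    private final ConnectionManager connectionManager;

    SchemaManager(ConnectionManager connectionManager) {
        this.connectionManager = connectionManager;
    }

    List<String> getPrimaryKeys(String database, String table) throws SQLException {
        final List<String> keys = new ArrayList<>();
        try (Connection connection = connectionManager.getConnection()) {
            final DatabaseMetaData metaData = connection.getMetaData();
            // MySQL exposes databases as JDBC catalogs.
            try (ResultSet rs = metaData.getPrimaryKeys(database, null, table)) {
                while (rs.next()) {
                    keys.add(rs.getString("COLUMN_NAME"));
                }
            }
        }
        return keys;
    }
}
```

A class like this would also be the natural place for the binlog start position query mentioned in the commit message (SHOW MASTER STATUS, renamed to SHOW BINARY LOG STATUS in MySQL 8.4).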
DataFileProgressState.java
@@ -7,6 +7,9 @@

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;
import java.util.Map;

public class DataFileProgressState {

@JsonProperty("isLoaded")
@@ -21,6 +24,12 @@ public class DataFileProgressState {
@JsonProperty("sourceTable")
private String sourceTable;

/**
* Map of table name to primary keys
*/
@JsonProperty("primaryKeyMap")
private Map<String, List<String>> primaryKeyMap;

@JsonProperty("snapshotTime")
private long snapshotTime;

@@ -63,4 +72,12 @@ public long getSnapshotTime() {
public void setSnapshotTime(long snapshotTime) {
this.snapshotTime = snapshotTime;
}

public Map<String, List<String>> getPrimaryKeyMap() {
return primaryKeyMap;
}

public void setPrimaryKeyMap(Map<String, List<String>> primaryKeyMap) {
this.primaryKeyMap = primaryKeyMap;
}
}
ExportProgressState.java
@@ -8,6 +8,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;
import java.util.Map;

/**
* Progress state for an EXPORT partition
@@ -32,6 +33,12 @@ public class ExportProgressState {
@JsonProperty("tables")
private List<String> tables;

/**
* Map of table name to primary keys
*/
@JsonProperty("primaryKeyMap")
private Map<String, List<String>> primaryKeyMap;

@JsonProperty("kmsKeyId")
private String kmsKeyId;

@@ -89,6 +96,14 @@ public void setTables(List<String> tables) {
this.tables = tables;
}

public Map<String, List<String>> getPrimaryKeyMap() {
return primaryKeyMap;
}

public void setPrimaryKeyMap(Map<String, List<String>> primaryKeyMap) {
this.primaryKeyMap = primaryKeyMap;
}

public String getKmsKeyId() {
return kmsKeyId;
}
DataFileLoader.java
@@ -93,8 +93,8 @@ public void run() {

DataFileProgressState progressState = dataFilePartition.getProgressState().get();

// TODO: primary key to be obtained by querying database schema
final String primaryKeyName = "id";
final String fullTableName = progressState.getSourceDatabase() + "." + progressState.getSourceTable();
final List<String> primaryKeys = progressState.getPrimaryKeyMap().getOrDefault(fullTableName, List.of());

final long snapshotTime = progressState.getSnapshotTime();
final long eventVersionNumber = snapshotTime - VERSION_OVERLAP_TIME_FOR_EXPORT.toMillis();
@@ -103,13 +103,14 @@
record,
progressState.getSourceDatabase(),
progressState.getSourceTable(),
List.of(primaryKeyName),
primaryKeys,
snapshotTime,
eventVersionNumber));
bufferAccumulator.add(transformedRecord);
eventCount.getAndIncrement();
bytesProcessedSummary.record(bytes);
} catch (Exception e) {
LOG.error("Failed to process record from object s3://{}/{}", bucket, objectKey, e);
throw new RuntimeException(e);
}
});
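
The loader now builds its lookup key as sourceDatabase + "." + sourceTable, so the primaryKeyMap carried in the progress state must be keyed by "database.table"; this replaces the previous hardcoded "id" column. A small illustration of the expected key format, with hypothetical database and table names:

```java
import java.util.List;
import java.util.Map;

// Illustrative only: the "database.table" key format that the loader's
// getOrDefault lookup above expects primaryKeyMap to use.
public class PrimaryKeyMapExample {
    public static void main(String[] args) {
        // Hypothetical database/table names and key columns.
        Map<String, List<String>> primaryKeyMap = Map.of(
                "mydb.customers", List.of("customer_id"),
                "mydb.orders", List.of("order_id", "line_number")); // composite key

        final String fullTableName = "mydb" + "." + "customers";
        // Tables without a primary key fall back to an empty list.
        List<String> primaryKeys = primaryKeyMap.getOrDefault(fullTableName, List.of());
        System.out.println(primaryKeys); // prints [customer_id]
    }
}
```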
ExportScheduler.java
@@ -30,6 +30,7 @@
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
@@ -189,12 +190,15 @@ private SnapshotInfo checkSnapshotStatus(String snapshotId, Duration timeout) {
LOG.debug("Start checking status of snapshot {}", snapshotId);
while (Instant.now().isBefore(endTime)) {
SnapshotInfo snapshotInfo = snapshotManager.checkSnapshotStatus(snapshotId);
String status = snapshotInfo.getStatus();
// Valid snapshot statuses are: available, copying, creating
// The status should never be "copying" here
if (SnapshotStatus.AVAILABLE.getStatusName().equals(status)) {
LOG.info("Snapshot {} is available.", snapshotId);
return snapshotInfo;

if (snapshotInfo != null) {
String status = snapshotInfo.getStatus();
// Valid snapshot statuses are: available, copying, creating
// The status should never be "copying" here
if (SnapshotStatus.AVAILABLE.getStatusName().equals(status)) {
LOG.info("Snapshot {} is available.", snapshotId);
return snapshotInfo;
}
}

LOG.debug("Snapshot {} is still creating. Wait and check later", snapshotId);
@@ -272,7 +276,7 @@ private BiConsumer<String, Throwable> completeExport(ExportPartition exportParti

// Create data file partitions for processing S3 files
List<String> dataFileObjectKeys = getDataFileObjectKeys(bucket, prefix, exportTaskId);
createDataFilePartitions(bucket, exportTaskId, dataFileObjectKeys, snapshotTime);
createDataFilePartitions(bucket, exportTaskId, dataFileObjectKeys, snapshotTime, state.getPrimaryKeyMap());

completeExportPartition(exportPartition);
}
@@ -301,7 +305,11 @@ private List<String> getDataFileObjectKeys(String bucket, String prefix, String
return objectKeys;
}

private void createDataFilePartitions(String bucket, String exportTaskId, List<String> dataFileObjectKeys, long snapshotTime) {
private void createDataFilePartitions(String bucket,
String exportTaskId,
List<String> dataFileObjectKeys,
long snapshotTime,
Map<String, List<String>> primaryKeyMap) {
LOG.info("Total of {} data files generated for export {}", dataFileObjectKeys.size(), exportTaskId);
AtomicInteger totalFiles = new AtomicInteger();
for (final String objectKey : dataFileObjectKeys) {
@@ -313,6 +321,7 @@ private void createDataFilePartitions(String bucket, String exportTaskId, List<S
progressState.setSourceDatabase(database);
progressState.setSourceTable(table);
progressState.setSnapshotTime(snapshotTime);
progressState.setPrimaryKeyMap(primaryKeyMap);

DataFilePartition dataFilePartition = new DataFilePartition(exportTaskId, bucket, objectKey, Optional.of(progressState));
sourceCoordinator.createPartition(dataFilePartition);
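
The widened null check in checkSnapshotStatus implements the "Handle describe exceptions" item: describeSnapshot now returns null on failure, and the polling loop simply waits and checks again. The related "Add retry to database queries" item is not visible in this excerpt; a generic sketch of such a bounded-retry helper, with illustrative attempt count and linear backoff (not the commit's actual code):

```java
import java.util.concurrent.Callable;

// Sketch of a bounded-retry wrapper like the one the commit message describes
// for database queries; the retry count and backoff policy are illustrative.
public final class Retry {
    private Retry() {}

    public static <T> T withRetry(Callable<T> task, int maxAttempts, long backoffMillis) {
        RuntimeException lastFailure = null;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            try {
                return task.call();
            } catch (Exception e) {
                lastFailure = new RuntimeException(
                        "Attempt " + attempt + " of " + maxAttempts + " failed", e);
                try {
                    Thread.sleep(backoffMillis * attempt); // linear backoff
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        }
        if (lastFailure == null) {
            throw new IllegalArgumentException("maxAttempts must be at least 1");
        }
        throw lastFailure;
    }
}
```

A call site would wrap a schema query, e.g. Retry.withRetry(() -> schemaManager.getPrimaryKeys(database, table), 3, 1000L).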
SnapshotManager.java
@@ -5,14 +5,15 @@

package org.opensearch.dataprepper.plugins.source.rds.export;

import org.opensearch.dataprepper.plugins.source.rds.leader.RdsApiStrategy;
import org.opensearch.dataprepper.plugins.source.rds.model.SnapshotInfo;

import java.util.UUID;

public class SnapshotManager {
private final SnapshotStrategy snapshotStrategy;
private final RdsApiStrategy snapshotStrategy;

public SnapshotManager(final SnapshotStrategy snapshotStrategy) {
public SnapshotManager(final RdsApiStrategy snapshotStrategy) {
this.snapshotStrategy = snapshotStrategy;
}

ClusterApiStrategy.java (renamed from ClusterSnapshotStrategy.java and moved from the export package to the leader package)
@@ -3,30 +3,48 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.rds.export;
package org.opensearch.dataprepper.plugins.source.rds.leader;

import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata;
import org.opensearch.dataprepper.plugins.source.rds.model.SnapshotInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.CreateDbClusterSnapshotRequest;
import software.amazon.awssdk.services.rds.model.CreateDbClusterSnapshotResponse;
import software.amazon.awssdk.services.rds.model.DBCluster;
import software.amazon.awssdk.services.rds.model.DescribeDbClusterSnapshotsRequest;
import software.amazon.awssdk.services.rds.model.DescribeDbClusterSnapshotsResponse;
import software.amazon.awssdk.services.rds.model.DescribeDbClustersRequest;
import software.amazon.awssdk.services.rds.model.DescribeDbClustersResponse;

import java.time.Instant;

/**
* This snapshot strategy works with RDS/Aurora Clusters
*/
public class ClusterSnapshotStrategy implements SnapshotStrategy {
private static final Logger LOG = LoggerFactory.getLogger(ClusterSnapshotStrategy.class);
public class ClusterApiStrategy implements RdsApiStrategy {
private static final Logger LOG = LoggerFactory.getLogger(ClusterApiStrategy.class);
private final RdsClient rdsClient;

public ClusterSnapshotStrategy(final RdsClient rdsClient) {
public ClusterApiStrategy(final RdsClient rdsClient) {
this.rdsClient = rdsClient;
}

@Override
public DbMetadata describeDb(String dbIdentifier) {
final DescribeDbClustersRequest request = DescribeDbClustersRequest.builder()
.dbClusterIdentifier(dbIdentifier)
.build();
try {
final DescribeDbClustersResponse response = rdsClient.describeDBClusters(request);
final DBCluster dbCluster = response.dbClusters().get(0);
return new DbMetadata(dbIdentifier, dbCluster.endpoint(), dbCluster.port());
} catch (Exception e) {
throw new RuntimeException("Failed to describe DB " + dbIdentifier, e);
}
}

@Override
public SnapshotInfo createSnapshot(String dbIdentifier, String snapshotId) {
CreateDbClusterSnapshotRequest request = CreateDbClusterSnapshotRequest.builder()
@@ -54,11 +72,15 @@ public SnapshotInfo describeSnapshot(String snapshotId) {
.dbClusterSnapshotIdentifier(snapshotId)
.build();

DescribeDbClusterSnapshotsResponse response = rdsClient.describeDBClusterSnapshots(request);
String snapshotArn = response.dbClusterSnapshots().get(0).dbClusterSnapshotArn();
String status = response.dbClusterSnapshots().get(0).status();
Instant createTime = response.dbClusterSnapshots().get(0).snapshotCreateTime();

return new SnapshotInfo(snapshotId, snapshotArn, createTime, status);
try {
DescribeDbClusterSnapshotsResponse response = rdsClient.describeDBClusterSnapshots(request);
String snapshotArn = response.dbClusterSnapshots().get(0).dbClusterSnapshotArn();
String status = response.dbClusterSnapshots().get(0).status();
Instant createTime = response.dbClusterSnapshots().get(0).snapshotCreateTime();
return new SnapshotInfo(snapshotId, snapshotArn, createTime, status);
} catch (Exception e) {
LOG.error("Failed to describe snapshot {}", snapshotId, e);
return null;
}
}
}
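
InstanceApiStrategy is the non-cluster counterpart introduced by the same rename, but its source is not shown in this excerpt. Its describeDb presumably uses the instance-level API; a sketch under that assumption (the class name here is illustrative, while the AWS SDK v2 calls are real):

```java
import org.opensearch.dataprepper.plugins.source.rds.model.DbMetadata;
import software.amazon.awssdk.services.rds.RdsClient;
import software.amazon.awssdk.services.rds.model.DBInstance;
import software.amazon.awssdk.services.rds.model.DescribeDbInstancesRequest;
import software.amazon.awssdk.services.rds.model.DescribeDbInstancesResponse;

// Sketch of the instance-level counterpart to ClusterApiStrategy.describeDb,
// assuming the standard DescribeDBInstances call; the real class may differ.
public class InstanceApiStrategySketch {
    private final RdsClient rdsClient;

    public InstanceApiStrategySketch(final RdsClient rdsClient) {
        this.rdsClient = rdsClient;
    }

    public DbMetadata describeDb(String dbIdentifier) {
        final DescribeDbInstancesRequest request = DescribeDbInstancesRequest.builder()
                .dbInstanceIdentifier(dbIdentifier)
                .build();
        try {
            final DescribeDbInstancesResponse response = rdsClient.describeDBInstances(request);
            final DBInstance dbInstance = response.dbInstances().get(0);
            // Unlike DBCluster, DBInstance exposes host and port via an Endpoint object.
            return new DbMetadata(dbIdentifier, dbInstance.endpoint().address(), dbInstance.endpoint().port());
        } catch (Exception e) {
            throw new RuntimeException("Failed to describe DB " + dbIdentifier, e);
        }
    }
}
```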