diff --git a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/ShardRecordProcessorCheckpointer.java b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/ShardRecordProcessorCheckpointer.java index ada048340..d5fcf062f 100644 --- a/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/ShardRecordProcessorCheckpointer.java +++ b/amazon-kinesis-client/src/main/java/software/amazon/kinesis/checkpoint/ShardRecordProcessorCheckpointer.java @@ -14,6 +14,8 @@ */ package software.amazon.kinesis.checkpoint; +import java.util.Optional; + import lombok.Getter; import lombok.NonNull; import lombok.RequiredArgsConstructor; @@ -36,7 +38,6 @@ * The Amazon Kinesis Client Library will instantiate an object and provide a reference to the application * ShardRecordProcessor instance. Amazon Kinesis Client Library will create one instance per shard assignment. */ -@RequiredArgsConstructor @Slf4j public class ShardRecordProcessorCheckpointer implements RecordProcessorCheckpointer { @NonNull @@ -45,6 +46,18 @@ public class ShardRecordProcessorCheckpointer implements RecordProcessorCheckpoi @Getter @Accessors(fluent = true) private final Checkpointer checkpointer; + private final SequenceNumberValidator sequenceNumberValidator; + + public ShardRecordProcessorCheckpointer(ShardInfo shardInfo, Checkpointer checkpointer) { + this(shardInfo, checkpointer, new SequenceNumberValidator()); + } + + public ShardRecordProcessorCheckpointer(ShardInfo shardInfo, Checkpointer checkpointer, SequenceNumberValidator sequenceNumberValidator) { + this.shardInfo = shardInfo; + this.checkpointer = checkpointer; + this.sequenceNumberValidator = sequenceNumberValidator; + } + // Set to the last value set via checkpoint(). // Sample use: verify application shutdown() invoked checkpoint() at the end of a shard. @Getter @Accessors(fluent = true) @@ -107,6 +120,22 @@ public synchronized void checkpoint(String sequenceNumber, long subSequenceNumbe + subSequenceNumber); } + if (sequenceNumberValidator != null) { + Optional validationResult = sequenceNumberValidator.validateSequenceNumberForShard(sequenceNumber, + shardInfo.shardId()); + if (!validationResult.isPresent()) { + log.error("[{}] Unable to extract shardId from {}", shardInfo.shardId(), sequenceNumber); + throw new IllegalArgumentException("Unable to extract shardId from " + sequenceNumber); + } + if (!validationResult.get()) { + String seqShardId = sequenceNumberValidator.shardIdFor(sequenceNumber).orElse("MissingShard"); + log.error("[{}] Sequence number {} encodes a different shard {}", shardInfo.shardId(), sequenceNumber, seqShardId); + throw new IllegalArgumentException(String.format( + "Sequence number %s encodes a different shard: %s than the expected shard: %s", sequenceNumber, + seqShardId, shardInfo.shardId())); + } + } + /* * If there isn't a last checkpoint value, we only care about checking the upper bound. * If there is a last checkpoint value, we want to check both the lower and upper bound. @@ -247,7 +276,7 @@ void advancePosition(ExtendedSequenceNumber extendedSequenceNumber) // just checkpoint at SHARD_END checkpointToRecord = ExtendedSequenceNumber.SHARD_END; } - + // Don't checkpoint a value we already successfully checkpointed if (extendedSequenceNumber != null && !extendedSequenceNumber.equals(lastCheckpointValue)) { try { diff --git a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/ShardShardRecordProcessorCheckpointerTest.java b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/ShardShardRecordProcessorCheckpointerTest.java index c46a85725..afe274ee8 100644 --- a/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/ShardShardRecordProcessorCheckpointerTest.java +++ b/amazon-kinesis-client/src/test/java/software/amazon/kinesis/checkpoint/ShardShardRecordProcessorCheckpointerTest.java @@ -19,13 +19,17 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; +import java.math.BigInteger; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map.Entry; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.mockito.runners.MockitoJUnitRunner; @@ -67,8 +71,7 @@ public void setup() throws Exception { @Test public final void testCheckpoint() throws Exception { // First call to checkpoint - ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); processingCheckpointer.largestPermittedCheckpointValue(startingExtendedSequenceNumber); processingCheckpointer.checkpoint(); assertThat(checkpoint.getCheckpoint(shardId), equalTo(startingExtendedSequenceNumber)); @@ -92,7 +95,7 @@ private Record makeRecord(String seqNum) { @Test public final void testCheckpointRecord() throws Exception { ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5025"); Record record = makeRecord("5025"); @@ -108,7 +111,7 @@ public final void testCheckpointRecord() throws Exception { @Test public final void testCheckpointSubRecord() throws Exception { ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5030"); Record record = makeRecord("5030"); @@ -125,7 +128,7 @@ public final void testCheckpointSubRecord() throws Exception { @Test public final void testCheckpointSequenceNumber() throws Exception { ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5035"); processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber); @@ -140,7 +143,7 @@ public final void testCheckpointSequenceNumber() throws Exception { @Test public final void testCheckpointExtendedSequenceNumber() throws Exception { ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5040"); processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber); @@ -154,7 +157,7 @@ public final void testCheckpointExtendedSequenceNumber() throws Exception { @Test public final void testCheckpointAtShardEnd() throws Exception { ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = ExtendedSequenceNumber.SHARD_END; processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber); @@ -171,7 +174,7 @@ public final void testCheckpointAtShardEnd() throws Exception { public final void testPrepareCheckpoint() throws Exception { // First call to checkpoint ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber sequenceNumber1 = new ExtendedSequenceNumber("5001"); @@ -202,7 +205,7 @@ public final void testPrepareCheckpoint() throws Exception { @Test public final void testPrepareCheckpointRecord() throws Exception { ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5025"); Record record = makeRecord("5025"); @@ -227,7 +230,7 @@ public final void testPrepareCheckpointRecord() throws Exception { @Test public final void testPrepareCheckpointSubRecord() throws Exception { ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5030"); Record record = makeRecord("5030"); @@ -252,7 +255,7 @@ public final void testPrepareCheckpointSubRecord() throws Exception { */ @Test public final void testPrepareCheckpointSequenceNumber() throws Exception { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5035"); processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber); @@ -275,7 +278,7 @@ public final void testPrepareCheckpointSequenceNumber() throws Exception { */ @Test public final void testPrepareCheckpointExtendedSequenceNumber() throws Exception { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5040"); processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber); @@ -297,7 +300,7 @@ public final void testPrepareCheckpointExtendedSequenceNumber() throws Exception */ @Test public final void testPrepareCheckpointAtShardEnd() throws Exception { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); ExtendedSequenceNumber extendedSequenceNumber = ExtendedSequenceNumber.SHARD_END; processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber); @@ -320,7 +323,7 @@ public final void testPrepareCheckpointAtShardEnd() throws Exception { */ @Test public final void testMultipleOutstandingCheckpointersHappyCase() throws Exception { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); processingCheckpointer.largestPermittedCheckpointValue(new ExtendedSequenceNumber("6040")); @@ -351,7 +354,7 @@ public final void testMultipleOutstandingCheckpointersHappyCase() throws Excepti */ @Test public final void testMultipleOutstandingCheckpointersOutOfOrder() throws Exception { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber); processingCheckpointer.largestPermittedCheckpointValue(new ExtendedSequenceNumber("7040")); @@ -386,7 +389,7 @@ public final void testMultipleOutstandingCheckpointersOutOfOrder() throws Except */ @Test public final void testUpdate() throws Exception { - ShardRecordProcessorCheckpointer checkpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer checkpointer = make(shardInfo, checkpoint); ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("10"); checkpointer.largestPermittedCheckpointValue(sequenceNumber); @@ -404,7 +407,7 @@ public final void testUpdate() throws Exception { */ @Test public final void testClientSpecifiedCheckpoint() throws Exception { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); // Several checkpoints we're gonna hit ExtendedSequenceNumber tooSmall = new ExtendedSequenceNumber("2"); @@ -485,7 +488,7 @@ public final void testClientSpecifiedCheckpoint() throws Exception { */ @Test public final void testClientSpecifiedTwoPhaseCheckpoint() throws Exception { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); // Several checkpoints we're gonna hit ExtendedSequenceNumber tooSmall = new ExtendedSequenceNumber("2"); @@ -585,6 +588,56 @@ public final void testClientSpecifiedTwoPhaseCheckpoint() throws Exception { checkpoint.getCheckpointObject(shardId).pendingCheckpoint(), equalTo(ExtendedSequenceNumber.SHARD_END)); } + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testSequenceNumberForDifferentShard() throws Exception { + String sequenceNumber = "49587497311274533994574834252742144236107130636007899138"; + String actualShardId = "shardId-000000000000"; + String shardId = "shardId-000000000001"; + + String expectedMessage = String.format( + "Sequence number %s encodes a different shard: %s than the expected shard: %s", sequenceNumber, + actualShardId, shardId); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(expectedMessage); + + ShardRecordProcessorCheckpointer checkpointer = new ShardRecordProcessorCheckpointer( + new ShardInfo(shardId, "", Collections.emptyList(), ExtendedSequenceNumber.TRIM_HORIZON), checkpoint); + + ExtendedSequenceNumber ex = new ExtendedSequenceNumber( + new BigInteger(sequenceNumber, 10).add(BigInteger.ONE).toString()); + checkpointer.largestPermittedCheckpointValue(ex); + + checkpointer.checkpoint(sequenceNumber, 0); + } + + @Test + public void testInvalidSequenceNumberThrows() throws Exception { + String sequenceNumber = "79587497311274533994574834252742144236107130636007899138"; + + String expectedMessage = "Unable to extract shardId from " + sequenceNumber; + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage(expectedMessage); + + ShardRecordProcessorCheckpointer checkpointer = new ShardRecordProcessorCheckpointer( + new ShardInfo(shardId, "", Collections.emptyList(), ExtendedSequenceNumber.TRIM_HORIZON), checkpoint); + + ExtendedSequenceNumber ex = new ExtendedSequenceNumber( + new BigInteger(sequenceNumber, 10).add(BigInteger.ONE).toString()); + checkpointer.largestPermittedCheckpointValue(ex); + + checkpointer.checkpoint(sequenceNumber, 0); + + } + + private ShardRecordProcessorCheckpointer make(ShardInfo shardInfo, Checkpointer checkpoint) { + return new ShardRecordProcessorCheckpointer(shardInfo, checkpoint, null); + } + private enum CheckpointAction { NONE, NO_SEQUENCE_NUMBER, WITH_SEQUENCE_NUMBER; } @@ -605,7 +658,7 @@ private enum CheckpointerType { @Test public final void testMixedCheckpointCalls() throws Exception { for (LinkedHashMap testPlan : getMixedCallsTestPlan()) { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.CHECKPOINTER); } } @@ -621,7 +674,7 @@ public final void testMixedCheckpointCalls() throws Exception { @Test public final void testMixedTwoPhaseCheckpointCalls() throws Exception { for (LinkedHashMap testPlan : getMixedCallsTestPlan()) { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.PREPARED_CHECKPOINTER); } } @@ -638,7 +691,7 @@ public final void testMixedTwoPhaseCheckpointCalls() throws Exception { @Test public final void testMixedTwoPhaseCheckpointCalls2() throws Exception { for (LinkedHashMap testPlan : getMixedCallsTestPlan()) { - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.PREPARE_THEN_CHECKPOINTER); } } @@ -790,7 +843,7 @@ private void testMixedCheckpointCalls(ShardRecordProcessorCheckpointer processin public final void testUnsetMetricsScopeDuringCheckpointing() throws Exception { // First call to checkpoint ShardRecordProcessorCheckpointer processingCheckpointer = - new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + make(shardInfo, checkpoint); ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("5019"); processingCheckpointer.largestPermittedCheckpointValue(sequenceNumber); processingCheckpointer.checkpoint(); @@ -800,7 +853,7 @@ public final void testUnsetMetricsScopeDuringCheckpointing() throws Exception { @Test public final void testSetMetricsScopeDuringCheckpointing() throws Exception { // First call to checkpoint - ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint); + ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint); ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("5019"); processingCheckpointer.largestPermittedCheckpointValue(sequenceNumber);