Skip to content

Commit

Permalink
Add validation in the checkpointer as well
Browse files Browse the repository at this point in the history
  • Loading branch information
pfifer committed Sep 20, 2018
1 parent 5011c51 commit bb638f1
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
*/
package software.amazon.kinesis.checkpoint;

import java.util.Optional;

import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
Expand All @@ -36,7 +38,6 @@
* The Amazon Kinesis Client Library will instantiate an object and provide a reference to the application
* ShardRecordProcessor instance. Amazon Kinesis Client Library will create one instance per shard assignment.
*/
@RequiredArgsConstructor
@Slf4j
public class ShardRecordProcessorCheckpointer implements RecordProcessorCheckpointer {
@NonNull
Expand All @@ -45,6 +46,18 @@ public class ShardRecordProcessorCheckpointer implements RecordProcessorCheckpoi
@Getter @Accessors(fluent = true)
private final Checkpointer checkpointer;

private final SequenceNumberValidator sequenceNumberValidator;

public ShardRecordProcessorCheckpointer(ShardInfo shardInfo, Checkpointer checkpointer) {
this(shardInfo, checkpointer, new SequenceNumberValidator());
}

public ShardRecordProcessorCheckpointer(ShardInfo shardInfo, Checkpointer checkpointer, SequenceNumberValidator sequenceNumberValidator) {
this.shardInfo = shardInfo;
this.checkpointer = checkpointer;
this.sequenceNumberValidator = sequenceNumberValidator;
}

// Set to the last value set via checkpoint().
// Sample use: verify application shutdown() invoked checkpoint() at the end of a shard.
@Getter @Accessors(fluent = true)
Expand Down Expand Up @@ -107,6 +120,22 @@ public synchronized void checkpoint(String sequenceNumber, long subSequenceNumbe
+ subSequenceNumber);
}

if (sequenceNumberValidator != null) {
Optional<Boolean> validationResult = sequenceNumberValidator.validateSequenceNumberForShard(sequenceNumber,
shardInfo.shardId());
if (!validationResult.isPresent()) {
log.error("[{}] Unable to extract shardId from {}", shardInfo.shardId(), sequenceNumber);
throw new IllegalArgumentException("Unable to extract shardId from " + sequenceNumber);
}
if (!validationResult.get()) {
String seqShardId = sequenceNumberValidator.shardIdFor(sequenceNumber).orElse("MissingShard");
log.error("[{}] Sequence number {} encodes a different shard {}", shardInfo.shardId(), sequenceNumber, seqShardId);
throw new IllegalArgumentException(String.format(
"Sequence number %s encodes a different shard: %s than the expected shard: %s", sequenceNumber,
seqShardId, shardInfo.shardId()));
}
}

/*
* If there isn't a last checkpoint value, we only care about checking the upper bound.
* If there is a last checkpoint value, we want to check both the lower and upper bound.
Expand Down Expand Up @@ -247,7 +276,7 @@ void advancePosition(ExtendedSequenceNumber extendedSequenceNumber)
// just checkpoint at SHARD_END
checkpointToRecord = ExtendedSequenceNumber.SHARD_END;
}

// Don't checkpoint a value we already successfully checkpointed
if (extendedSequenceNumber != null && !extendedSequenceNumber.equals(lastCheckpointValue)) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map.Entry;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.runners.MockitoJUnitRunner;

Expand Down Expand Up @@ -67,8 +71,7 @@ public void setup() throws Exception {
@Test
public final void testCheckpoint() throws Exception {
// First call to checkpoint
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
processingCheckpointer.largestPermittedCheckpointValue(startingExtendedSequenceNumber);
processingCheckpointer.checkpoint();
assertThat(checkpoint.getCheckpoint(shardId), equalTo(startingExtendedSequenceNumber));
Expand All @@ -92,7 +95,7 @@ private Record makeRecord(String seqNum) {
@Test
public final void testCheckpointRecord() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5025");
Record record = makeRecord("5025");
Expand All @@ -108,7 +111,7 @@ public final void testCheckpointRecord() throws Exception {
@Test
public final void testCheckpointSubRecord() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5030");
Record record = makeRecord("5030");
Expand All @@ -125,7 +128,7 @@ public final void testCheckpointSubRecord() throws Exception {
@Test
public final void testCheckpointSequenceNumber() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5035");
processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber);
Expand All @@ -140,7 +143,7 @@ public final void testCheckpointSequenceNumber() throws Exception {
@Test
public final void testCheckpointExtendedSequenceNumber() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5040");
processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber);
Expand All @@ -154,7 +157,7 @@ public final void testCheckpointExtendedSequenceNumber() throws Exception {
@Test
public final void testCheckpointAtShardEnd() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = ExtendedSequenceNumber.SHARD_END;
processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber);
Expand All @@ -171,7 +174,7 @@ public final void testCheckpointAtShardEnd() throws Exception {
public final void testPrepareCheckpoint() throws Exception {
// First call to checkpoint
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);

ExtendedSequenceNumber sequenceNumber1 = new ExtendedSequenceNumber("5001");
Expand Down Expand Up @@ -202,7 +205,7 @@ public final void testPrepareCheckpoint() throws Exception {
@Test
public final void testPrepareCheckpointRecord() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5025");
Record record = makeRecord("5025");
Expand All @@ -227,7 +230,7 @@ public final void testPrepareCheckpointRecord() throws Exception {
@Test
public final void testPrepareCheckpointSubRecord() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5030");
Record record = makeRecord("5030");
Expand All @@ -252,7 +255,7 @@ public final void testPrepareCheckpointSubRecord() throws Exception {
*/
@Test
public final void testPrepareCheckpointSequenceNumber() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5035");
processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber);
Expand All @@ -275,7 +278,7 @@ public final void testPrepareCheckpointSequenceNumber() throws Exception {
*/
@Test
public final void testPrepareCheckpointExtendedSequenceNumber() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = new ExtendedSequenceNumber("5040");
processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber);
Expand All @@ -297,7 +300,7 @@ public final void testPrepareCheckpointExtendedSequenceNumber() throws Exception
*/
@Test
public final void testPrepareCheckpointAtShardEnd() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
ExtendedSequenceNumber extendedSequenceNumber = ExtendedSequenceNumber.SHARD_END;
processingCheckpointer.largestPermittedCheckpointValue(extendedSequenceNumber);
Expand All @@ -320,7 +323,7 @@ public final void testPrepareCheckpointAtShardEnd() throws Exception {
*/
@Test
public final void testMultipleOutstandingCheckpointersHappyCase() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
processingCheckpointer.largestPermittedCheckpointValue(new ExtendedSequenceNumber("6040"));

Expand Down Expand Up @@ -351,7 +354,7 @@ public final void testMultipleOutstandingCheckpointersHappyCase() throws Excepti
*/
@Test
public final void testMultipleOutstandingCheckpointersOutOfOrder() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
processingCheckpointer.setInitialCheckpointValue(startingExtendedSequenceNumber);
processingCheckpointer.largestPermittedCheckpointValue(new ExtendedSequenceNumber("7040"));

Expand Down Expand Up @@ -386,7 +389,7 @@ public final void testMultipleOutstandingCheckpointersOutOfOrder() throws Except
*/
@Test
public final void testUpdate() throws Exception {
ShardRecordProcessorCheckpointer checkpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer checkpointer = make(shardInfo, checkpoint);

ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("10");
checkpointer.largestPermittedCheckpointValue(sequenceNumber);
Expand All @@ -404,7 +407,7 @@ public final void testUpdate() throws Exception {
*/
@Test
public final void testClientSpecifiedCheckpoint() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);

// Several checkpoints we're gonna hit
ExtendedSequenceNumber tooSmall = new ExtendedSequenceNumber("2");
Expand Down Expand Up @@ -485,7 +488,7 @@ public final void testClientSpecifiedCheckpoint() throws Exception {
*/
@Test
public final void testClientSpecifiedTwoPhaseCheckpoint() throws Exception {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);

// Several checkpoints we're gonna hit
ExtendedSequenceNumber tooSmall = new ExtendedSequenceNumber("2");
Expand Down Expand Up @@ -585,6 +588,56 @@ public final void testClientSpecifiedTwoPhaseCheckpoint() throws Exception {
checkpoint.getCheckpointObject(shardId).pendingCheckpoint(), equalTo(ExtendedSequenceNumber.SHARD_END));
}

@Rule
public ExpectedException thrown = ExpectedException.none();

@Test
public void testSequenceNumberForDifferentShard() throws Exception {
String sequenceNumber = "49587497311274533994574834252742144236107130636007899138";
String actualShardId = "shardId-000000000000";
String shardId = "shardId-000000000001";

String expectedMessage = String.format(
"Sequence number %s encodes a different shard: %s than the expected shard: %s", sequenceNumber,
actualShardId, shardId);

thrown.expect(IllegalArgumentException.class);
thrown.expectMessage(expectedMessage);

ShardRecordProcessorCheckpointer checkpointer = new ShardRecordProcessorCheckpointer(
new ShardInfo(shardId, "", Collections.emptyList(), ExtendedSequenceNumber.TRIM_HORIZON), checkpoint);

ExtendedSequenceNumber ex = new ExtendedSequenceNumber(
new BigInteger(sequenceNumber, 10).add(BigInteger.ONE).toString());
checkpointer.largestPermittedCheckpointValue(ex);

checkpointer.checkpoint(sequenceNumber, 0);
}

@Test
public void testInvalidSequenceNumberThrows() throws Exception {
String sequenceNumber = "79587497311274533994574834252742144236107130636007899138";

String expectedMessage = "Unable to extract shardId from " + sequenceNumber;

thrown.expect(IllegalArgumentException.class);
thrown.expectMessage(expectedMessage);

ShardRecordProcessorCheckpointer checkpointer = new ShardRecordProcessorCheckpointer(
new ShardInfo(shardId, "", Collections.emptyList(), ExtendedSequenceNumber.TRIM_HORIZON), checkpoint);

ExtendedSequenceNumber ex = new ExtendedSequenceNumber(
new BigInteger(sequenceNumber, 10).add(BigInteger.ONE).toString());
checkpointer.largestPermittedCheckpointValue(ex);

checkpointer.checkpoint(sequenceNumber, 0);

}

private ShardRecordProcessorCheckpointer make(ShardInfo shardInfo, Checkpointer checkpoint) {
return new ShardRecordProcessorCheckpointer(shardInfo, checkpoint, null);
}

private enum CheckpointAction {
NONE, NO_SEQUENCE_NUMBER, WITH_SEQUENCE_NUMBER;
}
Expand All @@ -605,7 +658,7 @@ private enum CheckpointerType {
@Test
public final void testMixedCheckpointCalls() throws Exception {
for (LinkedHashMap<String, CheckpointAction> testPlan : getMixedCallsTestPlan()) {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.CHECKPOINTER);
}
}
Expand All @@ -621,7 +674,7 @@ public final void testMixedCheckpointCalls() throws Exception {
@Test
public final void testMixedTwoPhaseCheckpointCalls() throws Exception {
for (LinkedHashMap<String, CheckpointAction> testPlan : getMixedCallsTestPlan()) {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.PREPARED_CHECKPOINTER);
}
}
Expand All @@ -638,7 +691,7 @@ public final void testMixedTwoPhaseCheckpointCalls() throws Exception {
@Test
public final void testMixedTwoPhaseCheckpointCalls2() throws Exception {
for (LinkedHashMap<String, CheckpointAction> testPlan : getMixedCallsTestPlan()) {
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);
testMixedCheckpointCalls(processingCheckpointer, testPlan, CheckpointerType.PREPARE_THEN_CHECKPOINTER);
}
}
Expand Down Expand Up @@ -790,7 +843,7 @@ private void testMixedCheckpointCalls(ShardRecordProcessorCheckpointer processin
public final void testUnsetMetricsScopeDuringCheckpointing() throws Exception {
// First call to checkpoint
ShardRecordProcessorCheckpointer processingCheckpointer =
new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
make(shardInfo, checkpoint);
ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("5019");
processingCheckpointer.largestPermittedCheckpointValue(sequenceNumber);
processingCheckpointer.checkpoint();
Expand All @@ -800,7 +853,7 @@ public final void testUnsetMetricsScopeDuringCheckpointing() throws Exception {
@Test
public final void testSetMetricsScopeDuringCheckpointing() throws Exception {
// First call to checkpoint
ShardRecordProcessorCheckpointer processingCheckpointer = new ShardRecordProcessorCheckpointer(shardInfo, checkpoint);
ShardRecordProcessorCheckpointer processingCheckpointer = make(shardInfo, checkpoint);

ExtendedSequenceNumber sequenceNumber = new ExtendedSequenceNumber("5019");
processingCheckpointer.largestPermittedCheckpointValue(sequenceNumber);
Expand Down

0 comments on commit bb638f1

Please sign in to comment.