Skip to content

Commit

Permalink
Added unit tests for the SnapshotStates
Browse files Browse the repository at this point in the history
Signed-off-by: Chris Helma <chelma+github@amazon.com>
  • Loading branch information
chelma committed May 2, 2024
1 parent 1e8d1f0 commit e4884d1
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 18 deletions.
14 changes: 9 additions & 5 deletions RFS/src/main/java/com/rfs/common/SnapshotCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public abstract class SnapshotCreator {
private static final ObjectMapper mapper = new ObjectMapper();

private final OpenSearchClient client;
public final String snapshotName;
private final String snapshotName;

public SnapshotCreator(String snapshotName, OpenSearchClient client) {
this.snapshotName = snapshotName;
Expand All @@ -26,6 +26,10 @@ public String getRepoName() {
return "migration_assistant_repo";
}

public String getSnapshotName() {
return snapshotName;
}

public void registerRepo() {
ObjectNode settings = getRequestBodyForRegisterRepo();

Expand Down Expand Up @@ -84,25 +88,25 @@ public boolean isSnapshotFinished() {
}
}

public class RepoRegistrationFailed extends RuntimeException {
public static class RepoRegistrationFailed extends RuntimeException {
public RepoRegistrationFailed(String repoName) {
super("Failed to register repo " + repoName);
}
}

public class SnapshotCreationFailed extends RuntimeException {
public static class SnapshotCreationFailed extends RuntimeException {
public SnapshotCreationFailed(String snapshotName) {
super("Failed to create snapshot " + snapshotName);
}
}

public class SnapshotDoesNotExist extends RuntimeException {
public static class SnapshotDoesNotExist extends RuntimeException {
public SnapshotDoesNotExist(String snapshotName) {
super("Snapshot " + snapshotName + " does not exist");
}
}

public class SnapshotStatusUnparsable extends RuntimeException {
public static class SnapshotStatusUnparsable extends RuntimeException {
public SnapshotStatusUnparsable(String snapshotName) {
super("Status of Snapshot " + snapshotName + " is not parsable");
}
Expand Down
3 changes: 1 addition & 2 deletions RFS/src/main/java/com/rfs/worker/RfsWorker.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.rfs.cms.CmsEntry.Snapshot;
import com.rfs.cms.CmsEntry.SnapshotStatus;
import com.rfs.common.SnapshotCreator;
import com.rfs.common.SnapshotCreator.SnapshotCreationFailed;

public class RfsWorker {
private static final Logger logger = LogManager.getLogger(RfsWorker.class);
Expand All @@ -26,7 +25,7 @@ public void run() throws Exception {

while (true) {
logger.info("Checking if work remains in the Snapshot Phase...");
Snapshot snapshotEntry = cmsClient.getSnapshotEntry(snapshotCreator.snapshotName);
Snapshot snapshotEntry = cmsClient.getSnapshotEntry(snapshotCreator.getSnapshotName());

if (snapshotEntry == null || snapshotEntry.status != SnapshotStatus.COMPLETED) {
WorkerState nextState = new SnapshotState.EnterPhase(globalState, cmsClient, snapshotCreator, snapshotEntry);
Expand Down
18 changes: 11 additions & 7 deletions RFS/src/main/java/com/rfs/worker/SnapshotState.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public CreateEntry(GlobalData globalState, CmsClient cmsClient, SnapshotCreator
@Override
public void run() {
logger.info("Snapshot CMS Entry not found, attempting to create it...");
cmsClient.createSnapshotEntry(snapshotCreator.snapshotName);
cmsClient.createSnapshotEntry(snapshotCreator.getSnapshotName());
logger.info("Snapshot CMS Entry created");
}

Expand All @@ -92,7 +92,7 @@ public void run() {
snapshotCreator.createSnapshot();

logger.info("Snapshot in progress...");
cmsClient.updateSnapshotEntry(snapshotCreator.snapshotName, SnapshotStatus.IN_PROGRESS);
cmsClient.updateSnapshotEntry(snapshotCreator.getSnapshotName(), SnapshotStatus.IN_PROGRESS);
}

@Override
Expand All @@ -108,16 +108,20 @@ public WaitForSnapshot(GlobalData globalState, CmsClient cmsClient, SnapshotCrea
super(globalState, cmsClient, snapshotCreator);
}

protected void waitABit() throws InterruptedException {
logger.info("Snapshot not finished yet; sleeping for 5 seconds...");
Thread.sleep(5000);
}

@Override
public void run() {
try{
while (!snapshotCreator.isSnapshotFinished()) {
logger.info("Snapshot not finished yet; sleeping for 5 seconds...");
Thread.sleep(5000);
waitABit();
}
} catch (InterruptedException e) {
logger.error("Interrupted while waiting for Snapshot to complete", e);
this.e = snapshotCreator.new SnapshotCreationFailed(snapshotCreator.snapshotName);
this.e = new SnapshotCreationFailed(snapshotCreator.getSnapshotName());
} catch (SnapshotCreationFailed e) {
this.e = e;
}
Expand All @@ -140,7 +144,7 @@ public ExitPhaseSuccess(GlobalData globalState, CmsClient cmsClient, SnapshotCre

@Override
public void run() {
cmsClient.updateSnapshotEntry(snapshotCreator.snapshotName, SnapshotStatus.COMPLETED);
cmsClient.updateSnapshotEntry(snapshotCreator.getSnapshotName(), SnapshotStatus.COMPLETED);
globalState.updatePhase(GlobalData.Phase.SNAPSHOT_COMPLETED);
logger.info("Snapshot completed, exiting Snapshot Phase...");
}
Expand All @@ -162,7 +166,7 @@ public ExitPhaseSnapshotFailed(GlobalData globalState, CmsClient cmsClient, Snap
@Override
public void run() {
logger.error("Snapshot creation failed");
cmsClient.updateSnapshotEntry(snapshotCreator.snapshotName, SnapshotStatus.FAILED);
cmsClient.updateSnapshotEntry(snapshotCreator.getSnapshotName(), SnapshotStatus.FAILED);
globalState.updatePhase(GlobalData.Phase.SNAPSHOT_FAILED);
throw e;
}
Expand Down
4 changes: 0 additions & 4 deletions RFS/src/test/java/com/rfs/common/S3RepoTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.api.Test;

import org.mockito.Mockito;
Expand Down
225 changes: 225 additions & 0 deletions RFS/src/test/java/com/rfs/worker/SnapshotStateTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package com.rfs.worker;

import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.api.Test;

import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

import com.rfs.cms.CmsClient;
import com.rfs.cms.CmsEntry.Snapshot;
import com.rfs.cms.CmsEntry.SnapshotStatus;
import com.rfs.common.SnapshotCreator;
import com.rfs.common.SnapshotCreator.SnapshotCreationFailed;
import com.rfs.worker.SnapshotState.ExitPhaseSnapshotFailed;
import com.rfs.worker.SnapshotState.ExitPhaseSuccess;
import com.rfs.worker.SnapshotState.WaitForSnapshot;

import static org.mockito.Mockito.*;


@ExtendWith(MockitoExtension.class)
public class SnapshotStateTest {

@Test
void EnterPhase_run_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);
Snapshot snapshotEntry = Mockito.mock(Snapshot.class);

// Run the test
SnapshotState.EnterPhase enterPhase = new SnapshotState.EnterPhase(globalState, cmsClient, snapshotCreator, snapshotEntry);
enterPhase.run();

// Check the results
Mockito.verify(globalState, times(1)).updatePhase(GlobalData.Phase.SNAPSHOT_IN_PROGRESS);
}

static Stream<Arguments> provideEnterPhaseNextArgs() {
return Stream.of(
Arguments.of(null, SnapshotState.CreateEntry.class),
Arguments.of(new Snapshot("test", SnapshotStatus.NOT_STARTED), SnapshotState.InitiateSnapshot.class),
Arguments.of(new Snapshot("test", SnapshotStatus.IN_PROGRESS), SnapshotState.WaitForSnapshot.class)
);
}

@ParameterizedTest
@MethodSource("provideEnterPhaseNextArgs")
void EnterPhase_nextState_AsExpected(Snapshot snapshotEntry, Class<?> expected) {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);

// Run the test
SnapshotState.EnterPhase enterPhase = new SnapshotState.EnterPhase(globalState, cmsClient, snapshotCreator, snapshotEntry);
WorkerState nextState = enterPhase.nextState();

// Check the results
assertEquals(expected, nextState.getClass());
}

@Test
void CreateEntry_run_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);
when(snapshotCreator.getSnapshotName()).thenReturn("test");

// Run the test
SnapshotState.CreateEntry createPhase = new SnapshotState.CreateEntry(globalState, cmsClient, snapshotCreator);
createPhase.run();

// Check the results
Mockito.verify(cmsClient, times(1)).createSnapshotEntry("test");
}

@Test
void CreateEntry_nextState_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);

// Run the test
SnapshotState.CreateEntry createPhase = new SnapshotState.CreateEntry(globalState, cmsClient, snapshotCreator);
WorkerState nextState = createPhase.nextState();

// Check the results
assertEquals(SnapshotState.InitiateSnapshot.class, nextState.getClass());
}

@Test
void InitiateSnapshot_run_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);
when(snapshotCreator.getSnapshotName()).thenReturn("test");

// Run the test
SnapshotState.InitiateSnapshot initiatePhase = new SnapshotState.InitiateSnapshot(globalState, cmsClient, snapshotCreator);
initiatePhase.run();

// Check the results
Mockito.verify(snapshotCreator, times(1)).registerRepo();
Mockito.verify(snapshotCreator, times(1)).createSnapshot();
Mockito.verify(cmsClient, times(1)).updateSnapshotEntry("test", SnapshotStatus.IN_PROGRESS);
}

@Test
void InitiateSnapshot_nextState_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);

// Run the test
SnapshotState.InitiateSnapshot initiatePhase = new SnapshotState.InitiateSnapshot(globalState, cmsClient, snapshotCreator);
WorkerState nextState = initiatePhase.nextState();

// Check the results
assertEquals(SnapshotState.WaitForSnapshot.class, nextState.getClass());
}

public static class TestableWaitForSnapshot extends WaitForSnapshot {
public TestableWaitForSnapshot(GlobalData globalState, CmsClient cmsClient, SnapshotCreator snapshotCreator) {
super(globalState, cmsClient, snapshotCreator);
}

protected void waitABit() throws InterruptedException {
// Do nothing
}
}

@Test
void WaitForSnapshot_successful_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);
when(snapshotCreator.isSnapshotFinished())
.thenReturn(false)
.thenReturn(true);

// Run the test
SnapshotState.WaitForSnapshot waitPhase = new TestableWaitForSnapshot(globalState, cmsClient, snapshotCreator);
waitPhase.run();
WorkerState nextState = waitPhase.nextState();

// Check the results
Mockito.verify(snapshotCreator, times(2)).isSnapshotFinished();
assertEquals(SnapshotState.ExitPhaseSuccess.class, nextState.getClass());
}

@Test
void WaitForSnapshot_failedSnapshot_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);
when(snapshotCreator.isSnapshotFinished())
.thenReturn(false)
.thenThrow(new SnapshotCreationFailed("test"));

// Run the test
SnapshotState.WaitForSnapshot waitPhase = new TestableWaitForSnapshot(globalState, cmsClient, snapshotCreator);
waitPhase.run();
WorkerState nextState = waitPhase.nextState();

// Check the results
Mockito.verify(snapshotCreator, times(2)).isSnapshotFinished();
assertEquals(SnapshotState.ExitPhaseSnapshotFailed.class, nextState.getClass());
}

@Test
void ExitPhaseSuccess_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);
when(snapshotCreator.getSnapshotName()).thenReturn("test");

// Run the test
SnapshotState.ExitPhaseSuccess exitPhase = new ExitPhaseSuccess(globalState, cmsClient, snapshotCreator);
exitPhase.run();
WorkerState nextState = exitPhase.nextState();

// Check the results
Mockito.verify(cmsClient, times(1)).updateSnapshotEntry("test", SnapshotStatus.COMPLETED);
Mockito.verify(globalState, times(1)).updatePhase(GlobalData.Phase.SNAPSHOT_COMPLETED);
assertEquals(null, nextState);
}

@Test
void ExitPhaseSnapshotFailed_AsExpected() {
// Set up the test
GlobalData globalState = Mockito.mock(GlobalData.class);
CmsClient cmsClient = Mockito.mock(CmsClient.class);
SnapshotCreator snapshotCreator = Mockito.mock(SnapshotCreator.class);
when(snapshotCreator.getSnapshotName()).thenReturn("test");
SnapshotCreationFailed e = new SnapshotCreationFailed("test");

// Run the test
SnapshotState.ExitPhaseSnapshotFailed exitPhase = new ExitPhaseSnapshotFailed(globalState, cmsClient, snapshotCreator, e);
assertThrows(SnapshotCreationFailed.class, () -> {
exitPhase.run();
});

// Check the results
Mockito.verify(cmsClient, times(1)).updateSnapshotEntry("test", SnapshotStatus.FAILED);
Mockito.verify(globalState, times(1)).updatePhase(GlobalData.Phase.SNAPSHOT_FAILED);
}

}

0 comments on commit e4884d1

Please sign in to comment.