Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

613: RS Token Cache Improvement #614

Merged
merged 10 commits into from
Oct 27, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,19 @@ public class ReportStreamOrderSender implements OrderSender {

private static final String OUR_PRIVATE_KEY_ID =
"trusted-intermediary-private-key-" + ApplicationContext.getEnvironment();
private static final String RS_TOKEN_CACHE_ID = "report-stream-token";

private static final String CLIENT_NAME = "flexion.etor-service-sender";
private static final Map<String, String> RS_AUTH_API_HEADERS =
Map.of("Content-Type", "application/x-www-form-urlencoded");

private String rsTokenCache;

protected synchronized String getRsTokenCache() {
return this.rsTokenCache;
}

protected synchronized void setRsTokenCache(String token) {
this.rsTokenCache = token;
}

@Inject private HttpClient client;
@Inject private AuthEngine jwt;
@Inject private Formatter formatter;
@Inject private HapiFhir fhir;
@Inject private Logger logger;
@Inject private Secrets secrets;
@Inject private Cache keyCache;
@Inject private Cache cache;

public static ReportStreamOrderSender getInstance() {
return INSTANCE;
Expand Down Expand Up @@ -93,19 +84,22 @@ protected void logRsSubmissionId(String rsResponseBody) {

protected String getRsToken() throws UnableToSendOrderException {
logger.logInfo("Looking up ReportStream token");
if (getRsTokenCache() != null && isValidToken()) {

var token = cache.get(RS_TOKEN_CACHE_ID);

if (token != null && isValidToken(token)) {
logger.logDebug("valid cache token");
return getRsTokenCache();
return token;
}

String token = requestToken();
setRsTokenCache(token);
token = requestToken();

cache.put(RS_TOKEN_CACHE_ID, token);

return token;
}

protected boolean isValidToken() {
String token = getRsTokenCache();
protected boolean isValidToken(String token) {
LocalDateTime expirationDate = jwt.getExpirationDate(token);

return LocalDateTime.now().isBefore(expirationDate.minus(15, ChronoUnit.SECONDS));
Expand Down Expand Up @@ -164,7 +158,7 @@ protected String requestToken() throws UnableToSendOrderException {
}

protected String retrievePrivateKey() throws SecretRetrievalException {
String key = keyCache.get(OUR_PRIVATE_KEY_ID);
String key = cache.get(OUR_PRIVATE_KEY_ID);
if (key != null) {
return key;
}
Expand All @@ -175,12 +169,12 @@ protected String retrievePrivateKey() throws SecretRetrievalException {
}

void cacheOurPrivateKeyIfNotCachedAlready(String privateKey) {
String key = keyCache.get(OUR_PRIVATE_KEY_ID);
String key = cache.get(OUR_PRIVATE_KEY_ID);
if (key != null) {
return;
}

keyCache.put(OUR_PRIVATE_KEY_ID, privateKey);
cache.put(OUR_PRIVATE_KEY_ID, privateKey);
}

protected String extractToken(String responseBody) throws FormatterProcessingException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@ import gov.hhs.cdc.trustedintermediary.external.inmemory.KeyCache
import gov.hhs.cdc.trustedintermediary.external.jackson.Jackson
import gov.hhs.cdc.trustedintermediary.wrappers.AuthEngine
import gov.hhs.cdc.trustedintermediary.wrappers.Cache
import gov.hhs.cdc.trustedintermediary.wrappers.Logger
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.Formatter
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.FormatterProcessingException
import gov.hhs.cdc.trustedintermediary.wrappers.HapiFhir
import gov.hhs.cdc.trustedintermediary.wrappers.HttpClient
import gov.hhs.cdc.trustedintermediary.wrappers.HttpClientException
import gov.hhs.cdc.trustedintermediary.wrappers.Logger
import gov.hhs.cdc.trustedintermediary.wrappers.Secrets
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.Formatter
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.FormatterProcessingException
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.TypeReference
import java.time.LocalDateTime
import java.time.temporal.ChronoUnit
import spock.lang.Specification

import java.util.concurrent.ConcurrentHashMap

class ReportStreamOrderSenderTest extends Specification {

def setup() {
Expand Down Expand Up @@ -310,61 +308,19 @@ class ReportStreamOrderSenderTest extends Specification {
def "ensure jwt that expires 15 seconds from now is valid"() {
given:
def mockAuthEngine = Mock(AuthEngine)
TestApplicationContext.register(AuthEngine, mockAuthEngine)

mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(20, ChronoUnit.SECONDS)
TestApplicationContext.register(OrderSender, ReportStreamOrderSender.getInstance())

TestApplicationContext.register(AuthEngine, mockAuthEngine)
TestApplicationContext.injectRegisteredImplementations()
ReportStreamOrderSender.getInstance().setRsTokenCache("our token from rs")

when:
def isValid = ReportStreamOrderSender.getInstance().isValidToken()
def isValid = ReportStreamOrderSender.getInstance().isValidToken("our token from rs")

then:
isValid
}

def "rsTokenCache getter and setter works, no synchronization"() {
given:
def rsOrderSender = ReportStreamOrderSender.getInstance()
def expected = "fake token"

when:
rsOrderSender.setRsTokenCache(expected)
def actual = rsOrderSender.getRsTokenCache()

then:
actual == expected
}

def "rsTokenCache synchronization works"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
def threadNums = 5
def iterations = 25
def table = new ConcurrentHashMap<String, Integer>()

when:
List<Thread> threads = []
(1..threadNums).each { threadId ->
threads.add(new Thread({
for(int i=0; i<iterations; i++) {
orderSender.setRsTokenCache("${i}")
if (i == 24) {
table.put("thread"+"${threadId}", i)
}
}
}))
}

threads*.start()
threads*.join()

then:
orderSender.getRsTokenCache() == "${iterations - 1}"
table.size() == threadNums
table.values().toSet().size() == 1
}

def "sendRequestBody bombs out due to http exception"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
Expand All @@ -385,97 +341,86 @@ class ReportStreamOrderSenderTest extends Specification {
exception.getCause().getClass() == HttpClientException
}

def "getRsToken when cache is empty"() {
def "getRsToken when cache is empty we call RS to get a new one"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
def mockClient = Mock(HttpClient)
def mockAuthEngine = Mock(AuthEngine)
def mockSecrets = Mock(Secrets)
def mockFormatter = Mock(Formatter)
def mockCache = Mock(Cache)

//make the cache empty
mockCache.get(_ as String) >> null

def freshTokenFromRs = "new token"
mockFormatter.convertJsonToObject(_, _ as TypeReference) >> [access_token: freshTokenFromRs]

TestApplicationContext.register(Formatter, mockFormatter)
TestApplicationContext.register(AuthEngine, mockAuthEngine)
TestApplicationContext.register(AuthEngine, Mock(AuthEngine))
TestApplicationContext.register(HttpClient, mockClient)
TestApplicationContext.register(Secrets, mockSecrets)
mockSecrets.getKey(_ as String) >> "fake private key"
TestApplicationContext.register(OrderSender, orderSender)
TestApplicationContext.injectRegisteredImplementations()
TestApplicationContext.register(Cache, mockCache)
TestApplicationContext.register(Secrets, Mock(Secrets))

mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS)
mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token"
mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token")
def responseBody = """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}"""
mockClient.post(_ as String, _ as Map, _ as String) >> responseBody
TestApplicationContext.injectRegisteredImplementations()

when:
def token = orderSender.getRsToken()
def token = ReportStreamOrderSender.getInstance().getRsToken()

then:
token == orderSender.getRsTokenCache()
1 * mockClient.post(_, _, _)
token == freshTokenFromRs
}

def "getRsToken when cache token is invalid"() {
def "getRsToken when cache token is invalid we call RS to get a new one"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
def mockClient = Mock(HttpClient)
def mockAuthEngine = Mock(AuthEngine)
def mockSecrets = Mock(Secrets)
def mockFormatter = Mock(Formatter)
def mockCache = Mock(Cache)

mockCache.get(_ as String) >> "shouldn't be returned"

//mock the auth engine so that the JWT looks like it is invalid
mockAuthEngine.getExpirationDate(_) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS)

def freshTokenFromRs = "new token"
mockFormatter.convertJsonToObject(_, _ as TypeReference) >> [access_token: freshTokenFromRs]

TestApplicationContext.register(Formatter, mockFormatter)
TestApplicationContext.register(AuthEngine, mockAuthEngine)
TestApplicationContext.register(HttpClient, mockClient)
TestApplicationContext.register(Secrets, mockSecrets)
mockSecrets.getKey(_ as String) >> "fakePrivateKey"
TestApplicationContext.register(OrderSender, orderSender)
TestApplicationContext.injectRegisteredImplementations()
TestApplicationContext.register(Cache, mockCache)
TestApplicationContext.register(Secrets, Mock(Secrets))

mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token"
mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS)
mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token")
def responseBody = """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}"""
mockClient.post(_ as String, _ as Map, _ as String) >> responseBody
orderSender.setRsTokenCache("Invalid Token")
TestApplicationContext.injectRegisteredImplementations()

when:
def token = orderSender.getRsToken()
def token = ReportStreamOrderSender.getInstance().getRsToken()

then:
token == orderSender.getRsTokenCache()
1 * mockClient.post(_, _, _)
token == freshTokenFromRs
}

def "getRsToken when cache token is valid"() {
def "getRsToken when cache token is valid, return that cached token"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
orderSender.setRsTokenCache("valid Token")
TestApplicationContext.register(OrderSender, orderSender)
def mockAuthEngine = Mock(AuthEngine)
def mockCache = Mock(Cache)

def mockFormatter = Mock(Formatter)
mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token")
TestApplicationContext.register(Formatter, mockFormatter)
def cachedRsToken = "DogCow goes Moof!"
mockCache.get(_ as String) >> cachedRsToken

def mockLogFormatter = Mock(Formatter)
mockLogFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> null
TestApplicationContext.register(Formatter, mockLogFormatter)
//mock the auth engine so that the JWT looks valid
mockAuthEngine.getExpirationDate(_) >> LocalDateTime.now().plus(60, ChronoUnit.SECONDS)

def mockAuthEngine = Mock(AuthEngine)
mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token"
mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(25, ChronoUnit.SECONDS)
TestApplicationContext.register(AuthEngine, mockAuthEngine)

def mockClient = Mock(HttpClient)
mockClient.post(_ as String, _ as Map, _ as String) >> """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}"""
TestApplicationContext.register(HttpClient, mockClient)

def mockSecrets = Mock(Secrets)
mockSecrets.getKey(_ as String) >> "fakePrivateKey"
TestApplicationContext.register(Secrets, mockSecrets)
TestApplicationContext.register(Cache, mockCache)

TestApplicationContext.injectRegisteredImplementations()

when:
def token = orderSender.getRsToken()
def token = ReportStreamOrderSender.getInstance().getRsToken()

then:
token == orderSender.getRsTokenCache()
token == cachedRsToken
}

def "logRsSubmissionId logs submissionId if convertJsonToObject is successful"() {
Expand Down