diff --git a/app/src/main/java/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidator.java b/app/src/main/java/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidator.java index f55c44b1c..6ad006dd6 100644 --- a/app/src/main/java/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidator.java +++ b/app/src/main/java/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidator.java @@ -14,17 +14,20 @@ /** * This class is used to check the validity of a http request. It has methods that extract the * bearer token, check if the token is empty or null, and if the token is valid. For example, - * expired tokens, empty tokens, or tokens not signed by our private key, will be invalid. + * expired tokens, empty tokens, or tokens not signed by our private key, will be invalid. Tokens + * are cached on first use, and removed if invalid. */ public class AuthRequestValidator { private static final AuthRequestValidator INSTANCE = new AuthRequestValidator(); @Inject private AuthEngine jwtEngine; - @Inject private Cache keyCache; + @Inject Cache keyCache; @Inject private Secrets secrets; @Inject private Logger logger; + String ourPublicKey = "trusted-intermediary-public-key-" + ApplicationContext.getEnvironment(); + private AuthRequestValidator() {} public static AuthRequestValidator getInstance() { @@ -49,12 +52,12 @@ public boolean isValidAuthenticatedRequest(DomainRequest request) return true; } catch (InvalidTokenException e) { logger.logError("Invalid bearer token!", e); + this.keyCache.remove(ourPublicKey); return false; } } protected String retrievePublicKey() throws SecretRetrievalException { - var ourPublicKey = "trusted-intermediary-public-key-" + ApplicationContext.getEnvironment(); String key = this.keyCache.get(ourPublicKey); if (key != null) { return key; diff --git a/app/src/test/groovy/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidatorTest.groovy b/app/src/test/groovy/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidatorTest.groovy index 3b20e104d..54b1470e2 100644 --- a/app/src/test/groovy/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidatorTest.groovy +++ b/app/src/test/groovy/gov/hhs/cdc/trustedintermediary/auth/AuthRequestValidatorTest.groovy @@ -6,7 +6,6 @@ import gov.hhs.cdc.trustedintermediary.external.inmemory.KeyCache import gov.hhs.cdc.trustedintermediary.external.jjwt.JjwtEngine import gov.hhs.cdc.trustedintermediary.wrappers.AuthEngine import gov.hhs.cdc.trustedintermediary.wrappers.Cache -import gov.hhs.cdc.trustedintermediary.wrappers.InvalidTokenException import gov.hhs.cdc.trustedintermediary.wrappers.Secrets import spock.lang.Specification @@ -227,21 +226,17 @@ class AuthRequestValidatorTest extends Specification{ def validator = AuthRequestValidator.getInstance() def token = "fake-token-here" def header = Map.of("Authorization", "Bearer " + token) - def mockEngine = Mock(JjwtEngine) - def mockCache = Mock(KeyCache) def request = new DomainRequest() def expected = false - TestApplicationContext.register(Cache, mockCache) - TestApplicationContext.register(AuthEngine, mockEngine) + TestApplicationContext.register(Cache, KeyCache.getInstance()) TestApplicationContext.injectRegisteredImplementations() when: request.setHeaders(header) - mockCache.get(_ as String) >> {"my-fake-private-key"} - mockEngine.validateToken(_ as String, _ as String) >> { throw new InvalidTokenException(new Throwable("fake exception"))} def actual = validator.isValidAuthenticatedRequest(request) then: actual == expected + validator.keyCache.get(validator.ourPublicKey) == null } } diff --git a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCache.java b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCache.java index 66b929ab4..28ae90b3d 100644 --- a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCache.java +++ b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCache.java @@ -30,4 +30,9 @@ public void put(String key, String value) { public String get(String key) { return keys.get(key); } + + @Override + public void remove(String key) { + keys.remove(key); + } } diff --git a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/wrappers/Cache.java b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/wrappers/Cache.java index a659f8b5e..946c2dfd1 100644 --- a/shared/src/main/java/gov/hhs/cdc/trustedintermediary/wrappers/Cache.java +++ b/shared/src/main/java/gov/hhs/cdc/trustedintermediary/wrappers/Cache.java @@ -6,4 +6,6 @@ public interface Cache { void put(String key, String value); String get(String key); + + void remove(String key); } diff --git a/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCacheTest.groovy b/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCacheTest.groovy index a45cdbec6..5666c6fed 100644 --- a/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCacheTest.groovy +++ b/shared/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/inmemory/KeyCacheTest.groovy @@ -43,4 +43,18 @@ class KeyCacheTest extends Specification { keys.values().toSet().size() == 1 // all entries have same value, threads had to wait on the lock } + + def "keyCache removal works"() { + given: + def cache = KeyCache.getInstance() + def value = "fake_key" + def key = "report_stream" + def expected = null + when: + cache.put(key, value) + cache.remove(key) + def actual = cache.get(key) + then: + actual == expected + } }