Skip to content

Commit

Permalink
Merge pull request #1333 from CDCgov/story/1326/no-cache-bad-key
Browse files Browse the repository at this point in the history
Remove Invalid Tokens from cache when found
  • Loading branch information
saquino0827 authored Sep 16, 2024
2 parents 108c8c4 + 1e26b25 commit 2ef6e94
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
/**
* This class is used to check the validity of a http request. It has methods that extract the
* bearer token, check if the token is empty or null, and if the token is valid. For example,
* expired tokens, empty tokens, or tokens not signed by our private key, will be invalid.
* expired tokens, empty tokens, or tokens not signed by our private key, will be invalid. Tokens
* are cached on first use, and removed if invalid.
*/
public class AuthRequestValidator {

private static final AuthRequestValidator INSTANCE = new AuthRequestValidator();

@Inject private AuthEngine jwtEngine;
@Inject private Cache keyCache;
@Inject Cache keyCache;
@Inject private Secrets secrets;
@Inject private Logger logger;

String ourPublicKey = "trusted-intermediary-public-key-" + ApplicationContext.getEnvironment();

private AuthRequestValidator() {}

public static AuthRequestValidator getInstance() {
Expand All @@ -49,12 +52,12 @@ public boolean isValidAuthenticatedRequest(DomainRequest request)
return true;
} catch (InvalidTokenException e) {
logger.logError("Invalid bearer token!", e);
this.keyCache.remove(ourPublicKey);
return false;
}
}

protected String retrievePublicKey() throws SecretRetrievalException {
var ourPublicKey = "trusted-intermediary-public-key-" + ApplicationContext.getEnvironment();
String key = this.keyCache.get(ourPublicKey);
if (key != null) {
return key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import gov.hhs.cdc.trustedintermediary.external.inmemory.KeyCache
import gov.hhs.cdc.trustedintermediary.external.jjwt.JjwtEngine
import gov.hhs.cdc.trustedintermediary.wrappers.AuthEngine
import gov.hhs.cdc.trustedintermediary.wrappers.Cache
import gov.hhs.cdc.trustedintermediary.wrappers.InvalidTokenException
import gov.hhs.cdc.trustedintermediary.wrappers.Secrets
import spock.lang.Specification

Expand Down Expand Up @@ -227,21 +226,17 @@ class AuthRequestValidatorTest extends Specification{
def validator = AuthRequestValidator.getInstance()
def token = "fake-token-here"
def header = Map.of("Authorization", "Bearer " + token)
def mockEngine = Mock(JjwtEngine)
def mockCache = Mock(KeyCache)
def request = new DomainRequest()
def expected = false
TestApplicationContext.register(Cache, mockCache)
TestApplicationContext.register(AuthEngine, mockEngine)
TestApplicationContext.register(Cache, KeyCache.getInstance())
TestApplicationContext.injectRegisteredImplementations()

when:
request.setHeaders(header)
mockCache.get(_ as String) >> {"my-fake-private-key"}
mockEngine.validateToken(_ as String, _ as String) >> { throw new InvalidTokenException(new Throwable("fake exception"))}
def actual = validator.isValidAuthenticatedRequest(request)

then:
actual == expected
validator.keyCache.get(validator.ourPublicKey) == null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,9 @@ public void put(String key, String value) {
public String get(String key) {
return keys.get(key);
}

@Override
public void remove(String key) {
keys.remove(key);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ public interface Cache {
void put(String key, String value);

String get(String key);

void remove(String key);
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,18 @@ class KeyCacheTest extends Specification {
keys.values().toSet().size() == 1 // all entries have same value, threads had to wait on the lock

}

def "keyCache removal works"() {
given:
def cache = KeyCache.getInstance()
def value = "fake_key"
def key = "report_stream"
def expected = null
when:
cache.put(key, value)
cache.remove(key)
def actual = cache.get(key)
then:
actual == expected
}
}

0 comments on commit 2ef6e94

Please sign in to comment.