Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: replace random access with round robin #687

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,46 @@ public final class JWT {

private static final Logger LOG = LoggerFactory.getLogger(JWT.class);

// simple random as its value is just to create entropy
private static final Random RND = new Random();
/**
* Internal holder of keys and usage tracking
*/
private static class KeyRing implements Iterable<JWS> {
private final List<JWS> signers = new ArrayList<>();
private int cnt = 0;

int size() {
return signers.size();
}

JWS get(int pos) {
return signers.get(pos);
}

JWS get() {
int size = signers.size();
switch (size) {
case 0:
return null;
case 1:
return signers.get(0);
default:
return signers.get(cnt++ % size);
}
}

JWS set(int pos, JWS jws) {
return signers.set(pos, jws);
}

boolean add(JWS jws) {
return signers.add(jws);
}

@Override
public Iterator<JWS> iterator() {
return signers.iterator();
}
}

private static final Charset UTF8 = StandardCharsets.UTF_8;

Expand All @@ -54,8 +92,8 @@ public final class JWT {
private MessageDigest nonceDigest;

// keep 2 maps (1 for sing, 1 for verify) this simplifies the lookups
private final Map<String, List<JWS>> SIGN = new ConcurrentHashMap<>();
private final Map<String, List<JWS>> VERIFY = new ConcurrentHashMap<>();
private final Map<String, KeyRing> SIGN = new ConcurrentHashMap<>();
private final Map<String, KeyRing> VERIFY = new ConcurrentHashMap<>();

/**
* Adds a JSON Web Key (rfc7517) to the signature maps.
Expand All @@ -66,14 +104,14 @@ public final class JWT {
public JWT addJWK(JWK jwk) {

if (jwk.use() == null || "sig".equals(jwk.use())) {
List<JWS> current;
KeyRing current;
synchronized (this) {
if (jwk.mac() != null || jwk.publicKey() != null) {
current = VERIFY.computeIfAbsent(jwk.getAlgorithm(), k -> new ArrayList<>());
current = VERIFY.computeIfAbsent(jwk.getAlgorithm(), k -> new KeyRing());
addJWK(current, jwk);
}
if (jwk.mac() != null || jwk.privateKey() != null) {
current = SIGN.computeIfAbsent(jwk.getAlgorithm(), k -> new ArrayList<>());
current = SIGN.computeIfAbsent(jwk.getAlgorithm(), k -> new KeyRing());
addJWK(current, jwk);
}
}
Expand Down Expand Up @@ -129,21 +167,21 @@ public JWT nonceAlgorithm(String alg) {
return this;
}

private void addJWK(List<JWS> current, JWK jwk) {
private void addJWK(KeyRing keyring, JWK jwk) {
boolean replaced = false;
for (int i = 0; i < current.size(); i++) {
if (current.get(i).jwk().label().equals(jwk.label())) {
for (int i = 0; i < keyring.size(); i++) {
if (keyring.get(i).jwk().label().equals(jwk.label())) {
// replace
LOG.info("replacing JWK with label " + jwk.label());
current.set(i, new JWS(jwk));
keyring.set(i, new JWS(jwk));
replaced = true;
break;
}
}

if (!replaced) {
// non existent, add it!
current.add(new JWS(jwk));
keyring.add(new JWS(jwk));
}
}

Expand Down Expand Up @@ -270,7 +308,7 @@ public JsonObject decode(final String token, boolean full, List<X509CRL> crls) t

// verify signature. `sign` will return base64 string.
if (!unsecure) {
List<JWS> signatures = VERIFY.get(alg);
KeyRing signatures = VERIFY.get(alg);

if (signatures == null || signatures.size() == 0) {
throw new NoSuchKeyIdException(alg);
Expand Down Expand Up @@ -331,14 +369,14 @@ public String sign(JsonObject payload, JWTOptions options) {
final String kid;

if (!unsecure) {
List<JWS> signatures = SIGN.get(algorithm);
KeyRing signatures = SIGN.get(algorithm);

if (signatures == null || signatures.size() == 0) {
throw new RuntimeException("Algorithm not supported/allowed: " + algorithm);
}

// lock the crypto implementation
jws = signatures.get(signatures.size() == 1 ? 0 : RND.nextInt(signatures.size()));
jws = signatures.get();
kid = jws.jwk().getId();
} else {
jws = null;
Expand Down
Loading