Skip to content

Commit

Permalink
Add check for domain and app names in BidStreamClient
Browse files Browse the repository at this point in the history
  • Loading branch information
asloobq committed Oct 7, 2024
1 parent 33fab07 commit 92278a7
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 17 deletions.
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.34</version>
<scope>provided</scope>
</dependency>
</dependencies>

<dependencyManagement>
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/com/uid2/client/BidstreamClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ public BidstreamClient(String baseUrl, String clientApiKey, String base64SecretK
tokenHelper = new TokenHelper(baseUrl, clientApiKey, base64SecretKey);
}

public DecryptionResponse decryptTokenIntoRawUid(String token, String domainNameFromBidRequest) {
return tokenHelper.decrypt(token, Instant.now(), domainNameFromBidRequest, ClientType.BIDSTREAM);
public DecryptionResponse decryptTokenIntoRawUid(String token, String domainOrAppNameFromBidRequest) {
return tokenHelper.decrypt(token, Instant.now(), domainOrAppNameFromBidRequest, ClientType.BIDSTREAM);
}

DecryptionResponse decryptTokenIntoRawUid(String token, String domainNameFromBidRequest, Instant now) {
return tokenHelper.decrypt(token, now, domainNameFromBidRequest, ClientType.BIDSTREAM);
DecryptionResponse decryptTokenIntoRawUid(String token, String domainOrAppNameFromBidRequest, Instant now) {
return tokenHelper.decrypt(token, now, domainOrAppNameFromBidRequest, ClientType.BIDSTREAM);
}

public RefreshResponse refresh() {
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/com/uid2/client/DecryptionStatus.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,9 @@ public enum DecryptionStatus {
/**
* INVALID_TOKEN_LIFETIME: The token has invalid timestamps.
*/
INVALID_TOKEN_LIFETIME
INVALID_TOKEN_LIFETIME,
/**
* DOMAIN_OR_APP_NAME_CHECK_FAILED: The supplied domain name or app name doesn't match with the allowed names of the participant who generated this token
*/
DOMAIN_OR_APP_NAME_CHECK_FAILED
}
17 changes: 16 additions & 1 deletion src/main/java/com/uid2/client/KeyContainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class KeyContainer {
private final HashMap<Long, Key> keys = new HashMap<>();
private final HashMap<Integer, List<Key>> keysBySite = new HashMap<>(); //for legacy /key/latest
private final HashMap<Integer, List<Key>> keysByKeyset = new HashMap<>();
private final Map<Integer, Site> siteIdToSite = new HashMap<>();
private Instant latestKeyExpiry;
private int callerSiteId;
private int masterKeysetId;
Expand Down Expand Up @@ -38,7 +39,7 @@ class KeyContainer {
}
}

KeyContainer(int callerSiteId, int masterKeysetId, int defaultKeysetId, long tokenExpirySeconds, List<Key> keyList, IdentityScope identityScope, long maxBidstreamLifetimeSeconds, long maxSharingLifetimeSeconds, long allowClockSkewSeconds) {
KeyContainer(int callerSiteId, int masterKeysetId, int defaultKeysetId, long tokenExpirySeconds, List<Key> keyList, List<Site> sites, IdentityScope identityScope, long maxBidstreamLifetimeSeconds, long maxSharingLifetimeSeconds, long allowClockSkewSeconds) {
this.callerSiteId = callerSiteId;
this.masterKeysetId = masterKeysetId;
this.defaultKeysetId = defaultKeysetId;
Expand All @@ -61,6 +62,10 @@ class KeyContainer {
for(Map.Entry<Integer, List<Key>> entry : keysByKeyset.entrySet()) {
entry.getValue().sort(Comparator.comparing(Key::getActivates));
}

for (Site site : sites) {
this.siteIdToSite.put(site.getId(), site);
}
}


Expand All @@ -82,6 +87,16 @@ public Key getMasterKey(Instant now)
return getKeysetActiveKey(masterKeysetId, now);
}

public boolean isDomainOrAppNameAllowedForSite(int siteId, String domainOrAppName) {
if (domainOrAppName == null) {
return false;
}
if (siteIdToSite.containsKey(siteId)) {
return siteIdToSite.get(siteId).allowDomainOrAppName(domainOrAppName);
}
return false;
}

private Key getKeysetActiveKey(int keysetId, Instant now)
{
List<Key> keyset = keysByKeyset.get(keysetId);
Expand Down
27 changes: 26 additions & 1 deletion src/main/java/com/uid2/client/KeyParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,35 @@ static KeyContainer parse(InputStream stream) {
keys.add(key);
}

return new KeyContainer(callerSiteId, masterKeysetId, defaultKeysetId, tokenExpirySeconds, keys, identityScope, maxBidstreamLifetimeSeconds, maxSharingLifetimeSeconds, allowClockSkewSeconds);
JsonArray sitesJson = body.getAsJsonArray("site_data");
List<Site> sites = new ArrayList<>();
if (!isNull(sitesJson)) {
for (JsonElement siteJson : sitesJson.asList()) {
Site site = getSiteFromJson(siteJson.getAsJsonObject());
if (site != null) {
sites.add(site);
}
}
}

return new KeyContainer(callerSiteId, masterKeysetId, defaultKeysetId, tokenExpirySeconds, keys, sites, identityScope, maxBidstreamLifetimeSeconds, maxSharingLifetimeSeconds, allowClockSkewSeconds);
}
}

private static Site getSiteFromJson(JsonObject siteJson) {
int siteId = getAsInt(siteJson, "id");
if (siteId == 0) {
return null;
}
JsonArray domainOrAppNamesJArray = siteJson.getAsJsonArray("domain_names");
List<String> domainOrAppNames = new ArrayList<>();
for (int i = 0; i < domainOrAppNamesJArray.size(); ++i) {
domainOrAppNames.add(domainOrAppNamesJArray.get(i).getAsString());
}

return new Site(siteId, domainOrAppNames);
}

static private int getAsInt(JsonObject body, String memberName) {
JsonElement element = body.get(memberName);
return isNull(element) ? 0 : element.getAsInt();
Expand Down
23 changes: 23 additions & 0 deletions src/main/java/com/uid2/client/Site.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.uid2.client;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import lombok.Getter;

@Getter
public class Site {
private final int id;

private final Set<String> domainOrAppNames;

public Site(int id, List<String> domainOrAppNames) {
this.id = id;
this.domainOrAppNames = new HashSet<>(domainOrAppNames);
}

public boolean allowDomainOrAppName(String domainOrAppName) {
// Using streams because HashSet's contains() is case sensitive
return domainOrAppNames.stream().anyMatch(domainOrAppName::equalsIgnoreCase);
}
}
4 changes: 2 additions & 2 deletions src/main/java/com/uid2/client/TokenHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TokenHelper {
this.uid2Helper = new Uid2Helper(base64SecretKey);
}

DecryptionResponse decrypt(String token, Instant now, String domainNameFromBidRequest, ClientType clientType) {
DecryptionResponse decrypt(String token, Instant now, String domainOrAppNameFromBidRequest, ClientType clientType) {
KeyContainer keyContainer = this.container.get();
if (keyContainer == null) {
return DecryptionResponse.makeError(DecryptionStatus.NOT_INITIALIZED);
Expand All @@ -26,7 +26,7 @@ DecryptionResponse decrypt(String token, Instant now, String domainNameFromBidRe
}

try {
return Uid2Encryption.decrypt(token, keyContainer, now, keyContainer.getIdentityScope(), domainNameFromBidRequest, clientType);
return Uid2Encryption.decrypt(token, keyContainer, now, keyContainer.getIdentityScope(), domainOrAppNameFromBidRequest, clientType);
} catch (Exception e) {
return DecryptionResponse.makeError(DecryptionStatus.INVALID_PAYLOAD);
}
Expand Down
32 changes: 24 additions & 8 deletions src/main/java/com/uid2/client/Uid2Encryption.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Uid2Encryption {
public static final int GCM_AUTHTAG_LENGTH = 16;
public static final int GCM_IV_LENGTH = 12;

static DecryptionResponse decrypt(String token, KeyContainer keys, Instant now, IdentityScope identityScope, String domainName, ClientType clientType) throws Exception {
static DecryptionResponse decrypt(String token, KeyContainer keys, Instant now, IdentityScope identityScope, String domainOrAppName, ClientType clientType) throws Exception {

if (token.length() < 4)
{
Expand All @@ -33,18 +33,18 @@ static DecryptionResponse decrypt(String token, KeyContainer keys, Instant now,

if (data[0] == 2)
{
return decryptV2(Base64.getDecoder().decode(token), keys, now, domainName, clientType);
return decryptV2(Base64.getDecoder().decode(token), keys, now, domainOrAppName, clientType);
}
//java byte is signed so we wanna convert to unsigned before checking the enum
int unsignedByte = ((int) data[1]) & 0xff;
if (unsignedByte == AdvertisingTokenVersion.V3.value())
{
return decryptV3(Base64.getDecoder().decode(token), keys, now, identityScope, domainName, clientType, 3);
return decryptV3(Base64.getDecoder().decode(token), keys, now, identityScope, domainOrAppName, clientType, 3);
}
else if (unsignedByte == AdvertisingTokenVersion.V4.value())
{
// Accept either base64 or base64url encoding.
return decryptV3(Base64.getDecoder().decode(base64UrlToBase64(token)), keys, now, identityScope, domainName, clientType, 4);
return decryptV3(Base64.getDecoder().decode(base64UrlToBase64(token)), keys, now, identityScope, domainOrAppName, clientType, 4);
}

return DecryptionResponse.makeError(DecryptionStatus.VERSION_NOT_SUPPORTED);
Expand All @@ -56,7 +56,7 @@ static String base64UrlToBase64(String value) {
.replace('_', '/');
}

static DecryptionResponse decryptV2(byte[] encryptedId, KeyContainer keys, Instant now, String domainName, ClientType clientType) throws Exception {
static DecryptionResponse decryptV2(byte[] encryptedId, KeyContainer keys, Instant now, String domainOrAppName, ClientType clientType) throws Exception {
try {
ByteBuffer rootReader = ByteBuffer.wrap(encryptedId);
int version = (int) rootReader.get();
Expand Down Expand Up @@ -108,6 +108,9 @@ static DecryptionResponse decryptV2(byte[] encryptedId, KeyContainer keys, Insta
if (now.isAfter(expiry)) {
return DecryptionResponse.makeError(DecryptionStatus.EXPIRED_TOKEN, established, siteId, siteKey.getSiteId(), null, advertisingTokenVersion, privacyBits.isClientSideGenerated(), expiry);
}
if (!isDomainOrAppNameAllowedForSite(clientType, privacyBits.isClientSideGenerated(), siteId, domainOrAppName, keys)) {
return DecryptionResponse.makeError(DecryptionStatus.DOMAIN_OR_APP_NAME_CHECK_FAILED, established, siteId, siteKey.getSiteId(), null, advertisingTokenVersion, privacyBits.isClientSideGenerated(), expiry);
}

if (!doesTokenHaveValidLifetime(clientType, keys, now, expiry, now)) {
return DecryptionResponse.makeError(DecryptionStatus.INVALID_TOKEN_LIFETIME, established, siteId, siteKey.getSiteId(), null, advertisingTokenVersion, privacyBits.isClientSideGenerated(), expiry);
Expand All @@ -119,7 +122,7 @@ static DecryptionResponse decryptV2(byte[] encryptedId, KeyContainer keys, Insta
}
}

static DecryptionResponse decryptV3(byte[] encryptedId, KeyContainer keys, Instant now, IdentityScope identityScope, String domainName, ClientType clientType, int advertisingTokenVersion) {
static DecryptionResponse decryptV3(byte[] encryptedId, KeyContainer keys, Instant now, IdentityScope identityScope, String domainOrAppName, ClientType clientType, int advertisingTokenVersion) {
try {
final IdentityType identityType = getIdentityType(encryptedId);
final ByteBuffer rootReader = ByteBuffer.wrap(encryptedId);
Expand Down Expand Up @@ -174,6 +177,9 @@ static DecryptionResponse decryptV3(byte[] encryptedId, KeyContainer keys, Insta
if (now.isAfter(expiry)) {
return DecryptionResponse.makeError(DecryptionStatus.EXPIRED_TOKEN, established, siteId, siteKey.getSiteId(), identityType, advertisingTokenVersion, privacyBits.isClientSideGenerated(), expiry);
}
if (!isDomainOrAppNameAllowedForSite(clientType, privacyBits.isClientSideGenerated(), siteId, domainOrAppName, keys)) {
return DecryptionResponse.makeError(DecryptionStatus.DOMAIN_OR_APP_NAME_CHECK_FAILED, established, siteId, siteKey.getSiteId(), identityType, advertisingTokenVersion, privacyBits.isClientSideGenerated(), expiry);
}

if (!doesTokenHaveValidLifetime(clientType, keys, generated, expiry, now)) {
return DecryptionResponse.makeError(DecryptionStatus.INVALID_TOKEN_LIFETIME, generated, siteId, siteKey.getSiteId(), identityType, advertisingTokenVersion, privacyBits.isClientSideGenerated(), expiry);
Expand Down Expand Up @@ -220,7 +226,7 @@ else if (!keys.isValid(now))
}


static EncryptionDataResponse encryptData(EncryptionDataRequest request, KeyContainer keys, IdentityScope identityScope, String domainName, ClientType clientType) {
static EncryptionDataResponse encryptData(EncryptionDataRequest request, KeyContainer keys, IdentityScope identityScope, String domainOrAppName, ClientType clientType) {
if (request.getData() == null) {
throw new IllegalArgumentException("data to encrypt must not be null");
}
Expand All @@ -241,7 +247,7 @@ static EncryptionDataResponse encryptData(EncryptionDataRequest request, KeyCont
siteKeySiteId = siteId;
} else {
try {
DecryptionResponse decryptedToken = decrypt(request.getAdvertisingToken(), keys, now, identityScope, domainName, clientType);
DecryptionResponse decryptedToken = decrypt(request.getAdvertisingToken(), keys, now, identityScope, domainOrAppName, clientType);
if (!decryptedToken.isSuccess()) {
return EncryptionDataResponse.makeError(EncryptionStatus.TOKEN_DECRYPT_FAILURE);
}
Expand Down Expand Up @@ -408,6 +414,16 @@ public CryptoException(Throwable inner) {
}
}

private static boolean isDomainOrAppNameAllowedForSite(ClientType clientType, boolean isClientSideGenerated, Integer siteId, String domainOrAppName, KeyContainer keys) {
if (!isClientSideGenerated) {
return true;
} else if (!clientType.equals(ClientType.BIDSTREAM) && !clientType.equals(ClientType.LEGACY)) {
return true;
} else {
return keys.isDomainOrAppNameAllowedForSite(siteId, domainOrAppName);
}
}

private static boolean doesTokenHaveValidLifetime(ClientType clientType, KeyContainer keys, Instant generatedOrNow, Instant expiry, Instant now) {
long maxLifetimeSeconds;
switch (clientType) {
Expand Down

0 comments on commit 92278a7

Please sign in to comment.