Skip to content

Commit

Permalink
Merge pull request #3562 from Thushani-Jayasekera/ws-apikey-enforcer
Browse files Browse the repository at this point in the history
[Choreo] Support passing API-Key for WebSocket Requests using sec-websocket-protocol Header
  • Loading branch information
renuka-fernando authored Aug 9, 2024
2 parents 474372c + 6bb2e73 commit 56727c9
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class RequestContext {
// For example, reason for denying a request
private String extAuthDetails;

private Map<String, String> responseHeadersToAddMap;

/**
* The dynamic metadata sent from enforcer are stored in this metadata map.
* @return dynamic metadata map
Expand Down Expand Up @@ -358,6 +360,20 @@ public void setExtAuthDetails(String extAuthDetails) {
this.extAuthDetails = extAuthDetails;
}

/**
* Specifies if headers needs to be added for the response based on request
*
* @return response headers to add map
*/
public Map<String, String> getResponseHeadersToAddMap() {
return responseHeadersToAddMap;
}

public void setResponseHeadersToAddMap(Map<String, String> responseHeadersToAddMap) {
this.responseHeadersToAddMap = responseHeadersToAddMap;
}


/**
* Implements builder pattern to build an {@link RequestContext} object.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class ResponseObject {
private String requestPath;
private String apiUuid;
private String extAuthDetails;
private Map<String, String> responseHeadersToAddMap;

public ArrayList<String> getRemoveHeaderMap() {
return removeHeaderMap;
Expand All @@ -48,6 +49,14 @@ public void setRemoveHeaderMap(ArrayList<String> removeHeaderMap) {
this.removeHeaderMap = removeHeaderMap;
}

public Map<String, String> getResponseHeadersToAddMap() {
return responseHeadersToAddMap;
}

public void setResponseHeadersToAddMap(Map<String, String> responseHeadersToAddMap) {
this.responseHeadersToAddMap = responseHeadersToAddMap;
}

public ResponseObject(String correlationID) {
this.correlationID = correlationID;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,17 @@ public ResponseObject process(RequestContext requestContext) {
Utils.populateRemoveAndProtectedHeaders(requestContext);

if (executeFilterChain(requestContext)) {
responseObject.setRemoveHeaderMap(requestContext.getRemoveHeaders());
responseObject.setQueryParamsToRemove(requestContext.getQueryParamsToRemove());
responseObject.setQueryParamMap(requestContext.getQueryParameters());
responseObject.setStatusCode(APIConstants.StatusCodes.OK.getCode());
if (requestContext.getAddHeaders() != null && requestContext.getAddHeaders().size() > 0) {
responseObject.setHeaderMap(requestContext.getAddHeaders());
}
if (requestContext.getResponseHeadersToAddMap() != null
&& requestContext.getResponseHeadersToAddMap().size() > 0) {
responseObject.setResponseHeadersToAddMap(requestContext.getResponseHeadersToAddMap());
}
logger.debug("ext_authz metadata: {}", requestContext.getMetadataMap());
responseObject.setMetaDataMap(requestContext.getMetadataMap());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,7 @@ public class Constants {
public static final String PROP_CON_FACTORY = "connectionfactory.TopicConnectionFactory";
public static final String DEFAULT_DESTINATION_TYPE = "Topic";
public static final String DEFAULT_CON_FACTORY_JNDI_NAME = "TopicConnectionFactory";

// keyword to identify API-Key sent in sec-websocket-protocol header
public static final String WS_API_KEY_IDENTIFIER = "choreo-internal-API-Key";
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ public class HttpConstants {
public static final String X_REQUEST_ID_HEADER = "x-request-id";
public static final String APPLICATION_JSON = "application/json";
public static final String BASIC_LOWER = "basic";
public static final String WEBSOCKET_PROTOCOL_HEADER = "sec-websocket-protocol";
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private CheckResponse buildResponse(CheckRequest request, ResponseObject respons
.build();
} else {
OkHttpResponse.Builder okResponseBuilder = OkHttpResponse.newBuilder();

// If the user is sending the APIKey credentials within query parameters, those query parameters should
// not be sent to the backend. Hence, the :path header needs to be constructed again removing the apiKey
// query parameter. In this scenario, apiKey query parameter is sent within the property called
Expand All @@ -175,6 +175,16 @@ private CheckResponse buildResponse(CheckRequest request, ResponseObject respons
}
);
}

if (responseObject.getResponseHeadersToAddMap() != null) {
responseObject.getResponseHeadersToAddMap().forEach((key, value) -> {
HeaderValueOption headerValueOption = HeaderValueOption.newBuilder()
.setHeader(HeaderValue.newBuilder().setKey(key).setValue(value).build())
.build();
okResponseBuilder.addResponseHeadersToAdd(headerValueOption);
}
);
}
okResponseBuilder.addAllHeadersToRemove(responseObject.getRemoveHeaderMap());
if (responseObject.getMetaDataMap() != null) {
responseObject.getMetaDataMap().forEach((key, value) ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.wso2.choreo.connect.enforcer.config.EnforcerConfig;
import org.wso2.choreo.connect.enforcer.constants.APIConstants;
import org.wso2.choreo.connect.enforcer.constants.APISecurityConstants;
import org.wso2.choreo.connect.enforcer.constants.Constants;
import org.wso2.choreo.connect.enforcer.constants.HttpConstants;
import org.wso2.choreo.connect.enforcer.dto.APIKeyValidationInfoDTO;
import org.wso2.choreo.connect.enforcer.dto.JWTTokenPayloadInfo;
import org.wso2.choreo.connect.enforcer.exception.APISecurityException;
Expand All @@ -47,6 +49,10 @@
import org.wso2.choreo.connect.enforcer.util.FilterUtils;

import java.text.ParseException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Implements the authenticator interface to authenticate request using an Internal Key.
Expand All @@ -69,8 +75,16 @@ public InternalAPIKeyAuthenticator(String securityParam) {

@Override
public boolean canAuthenticate(RequestContext requestContext) {
String apiType = requestContext.getMatchedAPI().getApiType();
String internalKey = requestContext.getHeaders().get(
ConfigHolder.getInstance().getConfig().getAuthHeader().getTestConsoleHeaderName().toLowerCase());
if (apiType.equalsIgnoreCase("WS")) {
if (internalKey == null) {
internalKey = extractInternalKeyInWSProtocolHeader(requestContext);
}
addWSProtocolResponseHeaderIfRequired(requestContext);
}

return isAPIKey(internalKey);
}

Expand Down Expand Up @@ -281,13 +295,65 @@ public String getName() {
}

private String extractInternalKey(RequestContext requestContext) {
String internalKey = requestContext.getHeaders().get(securityParam);
String internalKey;
internalKey = requestContext.getHeaders().get(securityParam);
if (internalKey != null) {
return internalKey.trim();
}
if (requestContext.getMatchedAPI().getApiType().equalsIgnoreCase("WS")) {
internalKey = extractInternalKeyInWSProtocolHeader(requestContext);
if (internalKey != null && !internalKey.isEmpty()) {
String protocols = getProtocolsToSetInRequestHeaders(requestContext);
if (protocols != null) {
requestContext.addOrModifyHeaders(HttpConstants.WEBSOCKET_PROTOCOL_HEADER, protocols);
}
return internalKey.trim();
}
}
return null;
}

public String extractInternalKeyInWSProtocolHeader(RequestContext requestContext) {
String protocolHeader = requestContext.getHeaders().get(
HttpConstants.WEBSOCKET_PROTOCOL_HEADER);
if (protocolHeader != null) {
String[] secProtocolHeaderValues = protocolHeader.split(",");
if (secProtocolHeaderValues.length > 1 && secProtocolHeaderValues[0].equals(
Constants.WS_API_KEY_IDENTIFIER)) {
return secProtocolHeaderValues[1].trim();
}
}
return "";
}

public String getProtocolsToSetInRequestHeaders(RequestContext requestContext) {
String[] secProtocolHeaderValues = requestContext.getHeaders().get(
HttpConstants.WEBSOCKET_PROTOCOL_HEADER).split(",");
if (secProtocolHeaderValues.length > 2) {
return Arrays.stream(secProtocolHeaderValues, 2, secProtocolHeaderValues.length)
.collect(Collectors.joining(",")).trim();
}
return null;
}

public void addWSProtocolResponseHeaderIfRequired(RequestContext requestContext) {
String secProtocolHeader = requestContext.getHeaders().get(HttpConstants.WEBSOCKET_PROTOCOL_HEADER);
if (secProtocolHeader != null) {
String[] secProtocolHeaderValues = secProtocolHeader.split(",");
if (secProtocolHeaderValues[0].equals(Constants.WS_API_KEY_IDENTIFIER) &&
secProtocolHeaderValues.length == 2) {
Map<String, String> responseHeadersToAddMap = requestContext.getResponseHeadersToAddMap();

if (responseHeadersToAddMap == null) {
responseHeadersToAddMap = new HashMap<>();
}
responseHeadersToAddMap.put(
HttpConstants.WEBSOCKET_PROTOCOL_HEADER, Constants.WS_API_KEY_IDENTIFIER);
requestContext.setResponseHeadersToAddMap(responseHeadersToAddMap);
}
}
}

@Override
public int getPriority() {
return -10;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.org) All Rights Reserved.
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.choreo.connect.enforcer.security.jwt;

import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.wso2.carbon.apimgt.common.gateway.dto.JWTConfigurationDto;
import org.wso2.choreo.connect.enforcer.commons.model.APIConfig;
import org.wso2.choreo.connect.enforcer.commons.model.RequestContext;
import org.wso2.choreo.connect.enforcer.config.ConfigHolder;
import org.wso2.choreo.connect.enforcer.config.EnforcerConfig;
import org.wso2.choreo.connect.enforcer.config.dto.CacheDto;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ConfigHolder.class})
@PowerMockIgnore("javax.management.*")
public class InternalAPIKeyAuthenticatorTest {

@Test
public void extractInternalKeyInWSProtocolHeaderTest() {
PowerMockito.mockStatic(ConfigHolder.class);
ConfigHolder configHolder = Mockito.mock(ConfigHolder.class);
EnforcerConfig enforcerConfig = Mockito.mock(EnforcerConfig.class);
CacheDto cacheDto = Mockito.mock(CacheDto.class);
Mockito.when(cacheDto.isEnabled()).thenReturn(true);
Mockito.when(enforcerConfig.getCacheDto()).thenReturn(cacheDto);
JWTConfigurationDto jwtConfigurationDto = Mockito.mock(JWTConfigurationDto.class);
Mockito.when(jwtConfigurationDto.isEnabled()).thenReturn(false);
Mockito.when(enforcerConfig.getJwtConfigurationDto()).thenReturn(jwtConfigurationDto);
Mockito.when(configHolder.getConfig()).thenReturn(enforcerConfig);
Mockito.when(ConfigHolder.getInstance()).thenReturn(configHolder);

String securityParam = "API-Key";

String mockToken = "eyJraWQiOiJnYXRld2F5XUlMyNTYifQlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYlzaGVyXC92XBpc1wvaW50ZXJuY." +
"eyJzdWIiOiJhMzllYGV2OjQ0M1wvYXBpXC9hbVwvcHVibGlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYWwta2V5Iiwia2V5dHlwZcl." +
"cnZpY2VcL3YxLjAiLCJwdWJsaXNoZXIiOiJjaG9yZW9fZGV2X2FwaW1fYWRtaW4iLCJ2ZXJzaW9uIjoidj7MIXRnS-2UWHdrmd7";

String secWebsocketProtocolHeader = "sec-websocket-protocol";

// Test case to test for an Upgrade request sent from the choreo console
// The token will be set to the sec-websocket-protocol header with choreo-internal-API-Key keyword
// the value after choreo-internal-API-Key will be the token
RequestContext.Builder builder = new RequestContext.Builder("/pets");
builder.matchedAPI(new APIConfig.Builder("Petstore")
.basePath("/choreo")
.apiType("WS")
.build());
Map<String, String> headersMap = new HashMap<>();
headersMap.put(
secWebsocketProtocolHeader,
"choreo-internal-API-Key," + mockToken);
builder.headers(headersMap);
RequestContext requestContext = builder.build();
InternalAPIKeyAuthenticator internalAPIKeyAuthenticator = new InternalAPIKeyAuthenticator(securityParam);
Assert.assertEquals(internalAPIKeyAuthenticator.extractInternalKeyInWSProtocolHeader(requestContext), mockToken);

// Test case to test for an Upgrade request sent from a client with api-key
RequestContext.Builder builder2 = new RequestContext.Builder("/pets");
builder2.matchedAPI(new APIConfig.Builder("Petstore")
.basePath("/choreo")
.apiType("WS")
.build());
Map<String, String> headersMap2 = new HashMap<>();
headersMap2.put(securityParam, mockToken);
builder2.headers(headersMap2);
RequestContext requestContext2 = builder2.build();
Assert.assertEquals(internalAPIKeyAuthenticator.extractInternalKeyInWSProtocolHeader(requestContext2), "");

}

@Test
public void getProtocolsToSetInRequestHeadersTest() {
PowerMockito.mockStatic(ConfigHolder.class);
ConfigHolder configHolder = Mockito.mock(ConfigHolder.class);
EnforcerConfig enforcerConfig = Mockito.mock(EnforcerConfig.class);
CacheDto cacheDto = Mockito.mock(CacheDto.class);
Mockito.when(cacheDto.isEnabled()).thenReturn(true);
Mockito.when(enforcerConfig.getCacheDto()).thenReturn(cacheDto);
JWTConfigurationDto jwtConfigurationDto = Mockito.mock(JWTConfigurationDto.class);
Mockito.when(jwtConfigurationDto.isEnabled()).thenReturn(false);
Mockito.when(enforcerConfig.getJwtConfigurationDto()).thenReturn(jwtConfigurationDto);
Mockito.when(configHolder.getConfig()).thenReturn(enforcerConfig);
Mockito.when(ConfigHolder.getInstance()).thenReturn(configHolder);

String securityParam = "API-Key";

String secWebsocketProtocolHeader = "sec-websocket-protocol";

String mockToken = "eyJraWQiOiJnYXRld2F5XUlMyNTYifQlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYlzaGVyXC92XBpc1wvaW50ZXJuY." +
"eyJzdWIiOiJhMzllYGV2OjQ0M1wvYXBpXC9hbVwvcHVibGlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYWwta2V5Iiwia2V5dHlwZcl." +
"cnZpY2VcL3YxLjAiLCJwdWJsaXNoZXIiOiJjaG9yZW9fZGV2X2FwaW1fYWRtaW4iLCJ2ZXJzaW9uIjoidj7MIXRnS-2UWHdrmd7";

RequestContext.Builder builder = new RequestContext.Builder("/pets");
builder.matchedAPI(new APIConfig.Builder("Petstore")
.basePath("/choreo")
.apiType("WS")
.build());
Map<String, String> headersMap = new HashMap<>();
headersMap.put(
secWebsocketProtocolHeader,
"choreo-internal-API-Key, " + mockToken + ", " + "chat, bar");
builder.headers(headersMap);
RequestContext requestContext = builder.build();
InternalAPIKeyAuthenticator internalAPIKeyAuthenticator = new InternalAPIKeyAuthenticator(securityParam);
Assert.assertEquals(internalAPIKeyAuthenticator.getProtocolsToSetInRequestHeaders(requestContext), "chat, bar");

}
}

0 comments on commit 56727c9

Please sign in to comment.