Skip to content

Commit

Permalink
Migrate Lambda mediator to AWS SDK 2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
msm1992 committed Feb 19, 2025
1 parent 522d8ed commit d3f444a
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,6 @@
*/
package org.wso2.carbon.apimgt.gateway.mediators;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.lambda.AWSLambda;
import com.amazonaws.services.lambda.AWSLambdaClientBuilder;
import com.amazonaws.services.lambda.model.InvocationType;
import com.amazonaws.services.lambda.model.InvokeRequest;
import com.amazonaws.services.lambda.model.InvokeResult;
import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest;
import com.amazonaws.services.securitytoken.model.AssumeRoleResult;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.apache.axiom.om.OMElement;
Expand All @@ -58,10 +39,29 @@
import org.wso2.carbon.apimgt.gateway.internal.ServiceReferenceHolder;
import org.wso2.carbon.apimgt.gateway.utils.redis.RedisCacheUtils;
import org.wso2.carbon.apimgt.impl.APIConstants;
import software.amazon.awssdk.auth.credentials.*;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.lambda.LambdaClient;
import software.amazon.awssdk.services.lambda.model.InvocationType;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleResponse;
import software.amazon.awssdk.services.sts.model.Credentials;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Iterator;
import java.util.Set;
import java.util.TreeMap;
Expand All @@ -79,7 +79,7 @@ public class AWSLambdaMediator extends AbstractMediator {
private String roleArn = "";
private String roleSessionName = "";
private String roleRegion = "";
private int resourceTimeout = APIConstants.AWS_DEFAULT_CONNECTION_TIMEOUT;
private Duration resourceTimeout = Duration.ofMillis(APIConstants.AWS_DEFAULT_CONNECTION_TIMEOUT);
private boolean isContentEncodingEnabled = false;
private static final String PATH_PARAMETERS = "pathParameters";
private static final String QUERY_STRING_PARAMETERS = "queryStringParameters";
Expand Down Expand Up @@ -165,15 +165,15 @@ public boolean mediate(MessageContext messageContext) {
log.debug("Passing the payload " + payload.toString() + " to AWS Lambda function with resource name "
+ resourceName);
}
InvokeResult invokeResult = invokeLambda(payload.toString());
InvokeResponse invokeResult = invokeLambda(payload.toString());

if (invokeResult != null) {
if (log.isDebugEnabled()) {
log.debug("AWS Lambda function: " + resourceName + " is invoked successfully.");
}
JsonUtil.getNewJsonPayload(axis2MessageContext, new String(invokeResult.getPayload().array()),
JsonUtil.getNewJsonPayload(axis2MessageContext, new String(invokeResult.payload().asByteArray(), StandardCharsets.UTF_8),
true, true);
axis2MessageContext.setProperty(APIMgtGatewayConstants.HTTP_SC, invokeResult.getStatusCode());
axis2MessageContext.setProperty(APIMgtGatewayConstants.HTTP_SC, invokeResult.statusCode());
axis2MessageContext.setProperty(APIMgtGatewayConstants.REST_MESSAGE_TYPE, APIConstants.APPLICATION_JSON_MEDIA_TYPE);
axis2MessageContext.setProperty(APIMgtGatewayConstants.REST_CONTENT_TYPE, APIConstants.APPLICATION_JSON_MEDIA_TYPE);
axis2MessageContext.removeProperty(APIConstants.NO_ENTITY_BODY);
Expand All @@ -198,39 +198,39 @@ public boolean mediate(MessageContext messageContext) {
* @param payload - input parameters to pass to AWS Lambda function as a JSONString
* @return InvokeResult
*/
private InvokeResult invokeLambda(String payload) {
private InvokeResponse invokeLambda(String payload) {
try {
// Validate resource timeout and set client configuration
if (resourceTimeout < 1000 || resourceTimeout > 900000) {
if (resourceTimeout.toMillis() < 1000 || resourceTimeout.toMillis() > 900000) {
setResourceTimeout(APIConstants.AWS_DEFAULT_CONNECTION_TIMEOUT);
}
ClientConfiguration clientConfig = new ClientConfiguration();
clientConfig.setSocketTimeout(resourceTimeout);
ClientOverrideConfiguration clientConfig = ClientOverrideConfiguration.builder()
.apiCallTimeout(resourceTimeout).build();

AWSLambda awsLambdaClient;
LambdaClient awsLambdaClient;
if (StringUtils.isEmpty(accessKey) && StringUtils.isEmpty(secretKey)) {
if (log.isDebugEnabled()) {
log.debug("Using temporary credentials supplied by the IAM role attached to AWS instance");
}
if (StringUtils.isEmpty(roleArn) && StringUtils.isEmpty(roleSessionName)
&& StringUtils.isEmpty(roleRegion)) {
awsLambdaClient = AWSLambdaClientBuilder.standard()
.withCredentials(DefaultAWSCredentialsProviderChain.getInstance())
.withClientConfiguration(clientConfig)
awsLambdaClient = LambdaClient.builder()
.credentialsProvider(DefaultCredentialsProvider.create())
.httpClientBuilder(ApacheHttpClient.builder())
.overrideConfiguration(clientConfig)
.build();
} else if (StringUtils.isNotEmpty(roleArn) && StringUtils.isNotEmpty(roleSessionName)
&& StringUtils.isNotEmpty(roleRegion)) {
Region region = new DefaultAwsRegionProviderChain().getRegion();
Credentials sessionCredentials = getSessionCredentials(
DefaultAWSCredentialsProviderChain.getInstance(), roleArn, roleSessionName,
String.valueOf(Regions.getCurrentRegion()));
BasicSessionCredentials basicSessionCredentials = new BasicSessionCredentials(
sessionCredentials.getAccessKeyId(),
sessionCredentials.getSecretAccessKey(),
sessionCredentials.getSessionToken());
awsLambdaClient = AWSLambdaClientBuilder.standard()
.withCredentials(new AWSStaticCredentialsProvider(basicSessionCredentials))
.withClientConfiguration(clientConfig)
.withRegion(roleRegion)
DefaultCredentialsProvider.create(), roleArn, roleSessionName,
String.valueOf(region));
AwsSessionCredentials basicSessionCredentials = AwsSessionCredentials.create(sessionCredentials.accessKeyId(), sessionCredentials.secretAccessKey(), sessionCredentials.sessionToken());
awsLambdaClient = LambdaClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(basicSessionCredentials))
.httpClientBuilder(ApacheHttpClient.builder())
.overrideConfiguration(clientConfig)
.region(Region.of(roleRegion))
.build();
} else {
log.error("Missing AWS STS configurations");
Expand All @@ -241,26 +241,25 @@ private InvokeResult invokeLambda(String payload) {
if (log.isDebugEnabled()) {
log.debug("Using user given stored credentials");
}
BasicAWSCredentials awsCredentials = new BasicAWSCredentials(accessKey, secretKey);
AwsBasicCredentials awsCredentials = AwsBasicCredentials.create(accessKey, secretKey);
if (StringUtils.isEmpty(roleArn) && StringUtils.isEmpty(roleSessionName)
&& StringUtils.isEmpty(roleRegion)) {
awsLambdaClient = AWSLambdaClientBuilder.standard()
.withCredentials(new AWSStaticCredentialsProvider(awsCredentials))
.withClientConfiguration(clientConfig)
.withRegion(region)
awsLambdaClient = LambdaClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(awsCredentials))
.httpClientBuilder(ApacheHttpClient.builder())
.overrideConfiguration(clientConfig)
.region(Region.of(region))
.build();
} else if (StringUtils.isNotEmpty(roleArn) && StringUtils.isNotEmpty(roleSessionName)
&& StringUtils.isNotEmpty(roleRegion)) {
Credentials sessionCredentials = getSessionCredentials(
new AWSStaticCredentialsProvider(awsCredentials), roleArn, roleSessionName, region);
BasicSessionCredentials basicSessionCredentials = new BasicSessionCredentials(
sessionCredentials.getAccessKeyId(),
sessionCredentials.getSecretAccessKey(),
sessionCredentials.getSessionToken());
awsLambdaClient = AWSLambdaClientBuilder.standard()
.withCredentials(new AWSStaticCredentialsProvider(basicSessionCredentials))
.withClientConfiguration(clientConfig)
.withRegion(roleRegion)
StaticCredentialsProvider.create(awsCredentials), roleArn, roleSessionName, region);
AwsSessionCredentials basicSessionCredentials = AwsSessionCredentials.create(sessionCredentials.accessKeyId(), sessionCredentials.secretAccessKey(), sessionCredentials.sessionToken());
awsLambdaClient = LambdaClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(basicSessionCredentials))
.httpClientBuilder(ApacheHttpClient.builder())
.overrideConfiguration(clientConfig)
.region(Region.of(roleRegion))
.build();
} else {
log.error("Missing AWS STS configurations");
Expand All @@ -270,20 +269,22 @@ private InvokeResult invokeLambda(String payload) {
log.error("Missing AWS Credentials");
return null;
}
InvokeRequest invokeRequest = new InvokeRequest()
.withFunctionName(resourceName)
.withPayload(payload)
.withInvocationType(InvocationType.RequestResponse)
.withSdkClientExecutionTimeout(resourceTimeout);

SdkBytes payloadBytes = SdkBytes.fromUtf8String(payload);
InvokeRequest invokeRequest = InvokeRequest.builder()
.functionName(resourceName)
.payload(payloadBytes)
.invocationType(InvocationType.REQUEST_RESPONSE)
.build();
return awsLambdaClient.invoke(invokeRequest);
} catch (SdkClientException e) {
} catch (SdkClientException | URISyntaxException e) {
log.error("Error while invoking the lambda function", e);
}
return null;
}

private Credentials getSessionCredentials(AWSCredentialsProvider credentialsProvider, String roleArn,
String roleSessionName, String region) {
private Credentials getSessionCredentials(AwsCredentialsProvider credentialsProvider, String roleArn,
String roleSessionName, String region) throws URISyntaxException {
Credentials sessionCredentials = null;
if (ServiceReferenceHolder.getInstance().isRedisEnabled()) {
Object previousCredentialsObject = new RedisCacheUtils(ServiceReferenceHolder.getInstance().getRedisPool())
Expand All @@ -295,30 +296,30 @@ private Credentials getSessionCredentials(AWSCredentialsProvider credentialsProv
sessionCredentials = CredentialsCache.getInstance().getCredentialsMap().get(roleSessionName);
}
if (sessionCredentials != null) {
long expirationTime = sessionCredentials.getExpiration().getTime();
long expirationTime = sessionCredentials.expiration().toEpochMilli();
long currentTime = System.currentTimeMillis();
long timeDifference = expirationTime - currentTime;
if (timeDifference > 1000) {
return sessionCredentials;
}
}
AWSSecurityTokenService awsSTSClient;
StsClient awsSTSClient;
if (StringUtils.isEmpty(region)) {
awsSTSClient = AWSSecurityTokenServiceClientBuilder.standard()
.withCredentials(credentialsProvider)
awsSTSClient = StsClient.builder()
.credentialsProvider(credentialsProvider)
.build();
} else {
awsSTSClient = AWSSecurityTokenServiceClientBuilder.standard()
.withCredentials(credentialsProvider)
.withEndpointConfiguration(new EndpointConfiguration("https://sts." + region + ".amazonaws.com",
region))
awsSTSClient = StsClient.builder()
.credentialsProvider(credentialsProvider)
.endpointOverride(new URI("https://sts." + region + ".amazonaws.com"))
.build();
}
AssumeRoleRequest roleRequest = new AssumeRoleRequest()
.withRoleArn(roleArn)
.withRoleSessionName(roleSessionName);
AssumeRoleResult assumeRoleResult = awsSTSClient.assumeRole(roleRequest);
sessionCredentials = assumeRoleResult.getCredentials();
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
.roleArn(roleArn)
.roleSessionName(roleSessionName)
.build();
AssumeRoleResponse assumeRoleResult = awsSTSClient.assumeRole(roleRequest);
sessionCredentials = assumeRoleResult.credentials();
if (ServiceReferenceHolder.getInstance().isRedisEnabled()) {
new RedisCacheUtils(ServiceReferenceHolder.getInstance().getRedisPool())
.addObject(roleSessionName, sessionCredentials);
Expand Down Expand Up @@ -416,7 +417,7 @@ public String getResourceName() {
}

public int getResourceTimeout() {
return resourceTimeout;
return (int) resourceTimeout.toMillis();
}

public void setAccessKey(String accessKey) {
Expand Down Expand Up @@ -448,7 +449,7 @@ public void setResourceName(String resourceName) {
}

public void setResourceTimeout(int resourceTimeout) {
this.resourceTimeout = resourceTimeout;
this.resourceTimeout = Duration.ofMillis(resourceTimeout);
}

public void setIsContentEncodingEnabled(boolean isContentEncodingEnabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
*/
package org.wso2.carbon.apimgt.gateway.mediators;

import com.amazonaws.services.securitytoken.model.Credentials;
import software.amazon.awssdk.services.sts.model.Credentials;

import java.util.HashMap;
import java.util.Map;
Expand Down
Loading

0 comments on commit d3f444a

Please sign in to comment.