Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request level plugin execution #4518

Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static software.amazon.awssdk.codegen.internal.Constant.EVENT_PUBLISHER_PARAM_NAME;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.addS3ArnableFieldCode;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.applySignerOverrideMethod;
import static software.amazon.awssdk.codegen.poet.client.SyncClientClass.addConfigurationUpdater;
import static software.amazon.awssdk.codegen.poet.client.SyncClientClass.getProtocolSpecs;

import com.squareup.javapoet.ClassName;
Expand All @@ -44,6 +45,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.reactivestreams.Publisher;
Expand All @@ -70,7 +72,9 @@
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils;
import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils;
import software.amazon.awssdk.core.async.SdkPublisher;
Expand All @@ -96,6 +100,7 @@ public final class AsyncClientClass extends AsyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final boolean useSraAuth;

public AsyncClientClass(GeneratorTaskParams dependencies) {
Expand All @@ -106,6 +111,7 @@ public AsyncClientClass(GeneratorTaskParams dependencies) {
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
}

@Override
Expand Down Expand Up @@ -140,7 +146,10 @@ protected void addFields(TypeSpec.Builder type) {
.addField(AsyncClientHandler.class, "clientHandler", PRIVATE, FINAL)
.addField(protocolSpec.protocolFactory(model))
.addField(SdkClientConfiguration.class, "clientConfiguration", PRIVATE, FINAL)
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL);
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL)
.addField(ParameterizedTypeName.get(BiFunction.class, SdkRequest.class,
SdkClientConfiguration.class, SdkClientConfiguration.class),
"clientConfigurationForRequest", PRIVATE, FINAL);

// Kinesis doesn't support CBOR for STS yet so need another protocol factory for JSON
if (model.getMetadata().isCborProtocol()) {
Expand Down Expand Up @@ -210,6 +219,7 @@ private MethodSpec constructor(TypeSpec.Builder classBuilder) {
.addStatement("this.clientHandler = new $T(clientConfiguration)", AwsAsyncClientHandler.class)
.addStatement("this.clientConfiguration = clientConfiguration")
.addStatement("this.serviceClientConfiguration = serviceClientConfiguration");
builder.addCode(addConfigurationUpdater(configurationUtils.serviceClientConfigurationBuilderClassName()));
FieldSpec protocolFactoryField = protocolSpec.protocolFactory(model);
if (model.getMetadata().isJsonProtocol()) {
builder.addStatement("this.$N = init($T.builder()).build()", protocolFactoryField.name,
Expand Down Expand Up @@ -294,7 +304,8 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation

builder.addModifiers(PUBLIC)
.addAnnotation(Override.class);

builder.addStatement("$T clientConfiguration = this.clientConfigurationForRequest.apply($L, this.clientConfiguration)",
SdkClientConfiguration.class, opModel.getInput().getVariableName());
builder.addStatement("$T<$T> metricPublishers = "
+ "resolveMetricPublishers(clientConfiguration, $N.overrideConfiguration().orElse(null))",
List.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.applySignerOverrideMethod;

import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.WildcardTypeName;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand All @@ -52,9 +55,15 @@
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkServiceClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.client.config.internal.ConfigurationUpdater;
import software.amazon.awssdk.core.client.config.internal.SdkClientConfigurationUtil;
import software.amazon.awssdk.core.client.handler.SyncClientHandler;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest;
Expand All @@ -73,6 +82,7 @@ public class SyncClientClass extends SyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final boolean useSraAuth;

public SyncClientClass(GeneratorTaskParams taskParams) {
Expand All @@ -82,6 +92,7 @@ public SyncClientClass(GeneratorTaskParams taskParams) {
this.className = poetExtensions.getClientClass(model.getMetadata().getSyncClient());
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
}

Expand Down Expand Up @@ -113,7 +124,10 @@ protected void addFields(TypeSpec.Builder type) {
.addField(SyncClientHandler.class, "clientHandler", PRIVATE, FINAL)
.addField(protocolSpec.protocolFactory(model))
.addField(SdkClientConfiguration.class, "clientConfiguration", PRIVATE, FINAL)
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL);
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL)
.addField(ParameterizedTypeName.get(BiFunction.class, SdkRequest.class,
SdkClientConfiguration.class, SdkClientConfiguration.class),
"clientConfigurationForRequest", PRIVATE, FINAL);
}

@Override
Expand Down Expand Up @@ -176,6 +190,8 @@ private MethodSpec constructor() {
.addStatement("this.clientHandler = new $T(clientConfiguration)", protocolSpec.getClientHandlerClass())
.addStatement("this.clientConfiguration = clientConfiguration")
.addStatement("this.serviceClientConfiguration = serviceClientConfiguration");

builder.addCode(addConfigurationUpdater(configurationUtils.serviceClientConfigurationBuilderClassName()));
FieldSpec protocolFactoryField = protocolSpec.protocolFactory(model);
if (model.getMetadata().isJsonProtocol()) {
builder.addStatement("this.$N = init($T.builder()).build()", protocolFactoryField.name,
Expand Down Expand Up @@ -207,6 +223,26 @@ private MethodSpec constructor() {
return builder.build();
}

static CodeBlock addConfigurationUpdater(TypeName serviceClientConfigurationBuilderClassName) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add("$T configurationUpdater = ",
ParameterizedTypeName.get(ConfigurationUpdater.class, SdkServiceClientConfiguration.Builder.class));
builder.add("(consumer, configBuilder) -> {\n$>")
.addStatement("$1T.BuilderInternal serviceConfigBuilder = $1T.builder(configBuilder)",
serviceClientConfigurationBuilderClassName)
.addStatement("consumer.accept(serviceConfigBuilder)")
.addStatement("return serviceConfigBuilder.buildSdkClientConfiguration()")
.add("$<};\n");
builder.add("this.clientConfigurationForRequest = (request, config) -> {\n$>");
builder.addStatement("$T plugins = request.overrideConfiguration()\n"
+ ".map(c -> c.registeredPlugins()).orElse(Collections.emptyList())",
ParameterizedTypeName.get(List.class, SdkPlugin.class));
builder.addStatement("return $T.invokePlugins(config, plugins, configurationUpdater)", SdkClientConfigurationUtil.class);

builder.add("$<};\n");
return builder.build();
}

@Override
protected List<MethodSpec> operations() {
return model.getOperations().values().stream()
Expand Down Expand Up @@ -289,6 +325,8 @@ private MethodSpec traditionalMethod(OperationModel opModel) {
method.endControlFlow();
}

method.addStatement("$T clientConfiguration = this.clientConfigurationForRequest.apply($L, this.clientConfiguration)",
SdkClientConfiguration.class, opModel.getInput().getVariableName());
method.addStatement("$T<$T> metricPublishers = "
+ "resolveMetricPublishers(clientConfiguration, $N.overrideConfiguration().orElse(null))",
List.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)\n", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
Expand Down Expand Up @@ -259,6 +260,7 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(hasInitialRequestEvent(opModel, isRestJson))
.add(".withResponseHandler($L)\n", responseHandlerName(opModel, isRestJson))
.add(".withErrorResponseHandler(errorResponseHandler)\n")
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
.add(credentialType(opModel, intermediateModel))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
Expand Down Expand Up @@ -156,6 +157,7 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(".withResponseHandler(responseHandler)\n")
.add(".withErrorResponseHandler(errorResponseHandler)\n")
.add(credentialType(opModel, intermediateModel))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
hostPrefixExpression(opModel) +
discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));
Expand Down Expand Up @@ -198,11 +199,11 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
if (opModel.hasEventStreamOutput()) {
executionResponseTransformerName = "restAsyncResponseTransformer";
}

builder.add("\n\n$T<$T> executeFuture = clientHandler.execute(new $T<$T, $T>()\n",
CompletableFuture.class, executeFutureValueType,
ClientExecutionParams.class, requestType, pojoResponseType)
.add(".withOperationName(\"$N\")\n", opModel.getOperationName())
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMarshaller($L)\n", asyncMarshaller(intermediateModel, opModel, marshaller, "protocolFactory"));

if (opModel.hasEventStreamOutput()) {
Expand Down
Loading
Loading