Skip to content

Commit

Permalink
Request level plugin execution
Browse files Browse the repository at this point in the history
  • Loading branch information
sugmanue committed Oct 2, 2023
1 parent f3b5320 commit b4d643a
Show file tree
Hide file tree
Showing 42 changed files with 2,916 additions and 1,716 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static software.amazon.awssdk.codegen.internal.Constant.EVENT_PUBLISHER_PARAM_NAME;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.addS3ArnableFieldCode;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.applySignerOverrideMethod;
import static software.amazon.awssdk.codegen.poet.client.SyncClientClass.addConfigurationUpdater;
import static software.amazon.awssdk.codegen.poet.client.SyncClientClass.getProtocolSpecs;

import com.squareup.javapoet.ClassName;
Expand All @@ -44,6 +45,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.reactivestreams.Publisher;
Expand All @@ -70,7 +72,9 @@
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils;
import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils;
import software.amazon.awssdk.core.async.SdkPublisher;
Expand All @@ -96,6 +100,7 @@ public final class AsyncClientClass extends AsyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final boolean useSraAuth;

public AsyncClientClass(GeneratorTaskParams dependencies) {
Expand All @@ -106,6 +111,7 @@ public AsyncClientClass(GeneratorTaskParams dependencies) {
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
}

@Override
Expand Down Expand Up @@ -140,7 +146,10 @@ protected void addFields(TypeSpec.Builder type) {
.addField(AsyncClientHandler.class, "clientHandler", PRIVATE, FINAL)
.addField(protocolSpec.protocolFactory(model))
.addField(SdkClientConfiguration.class, "clientConfiguration", PRIVATE, FINAL)
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL);
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL)
.addField(ParameterizedTypeName.get(BiFunction.class, SdkRequest.class,
SdkClientConfiguration.class, SdkClientConfiguration.class),
"clientConfigurationForRequest", PRIVATE, FINAL);

// Kinesis doesn't support CBOR for STS yet so need another protocol factory for JSON
if (model.getMetadata().isCborProtocol()) {
Expand Down Expand Up @@ -210,6 +219,7 @@ private MethodSpec constructor(TypeSpec.Builder classBuilder) {
.addStatement("this.clientHandler = new $T(clientConfiguration)", AwsAsyncClientHandler.class)
.addStatement("this.clientConfiguration = clientConfiguration")
.addStatement("this.serviceClientConfiguration = serviceClientConfiguration");
builder.addCode(addConfigurationUpdater(configurationUtils.serviceClientConfigurationBuilderClassName()));
FieldSpec protocolFactoryField = protocolSpec.protocolFactory(model);
if (model.getMetadata().isJsonProtocol()) {
builder.addStatement("this.$N = init($T.builder()).build()", protocolFactoryField.name,
Expand Down Expand Up @@ -294,7 +304,8 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation

builder.addModifiers(PUBLIC)
.addAnnotation(Override.class);

builder.addStatement("$T clientConfiguration = this.clientConfigurationForRequest.apply($L, this.clientConfiguration)",
SdkClientConfiguration.class, opModel.getInput().getVariableName());
builder.addStatement("$T<$T> metricPublishers = "
+ "resolveMetricPublishers(clientConfiguration, $N.overrideConfiguration().orElse(null))",
List.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.applySignerOverrideMethod;

import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.CodeBlock;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.TypeSpec;
import com.squareup.javapoet.WildcardTypeName;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand All @@ -52,9 +55,15 @@
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkServiceClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.client.config.internal.ConfigurationUpdater;
import software.amazon.awssdk.core.client.config.internal.SdkClientConfigurationUtil;
import software.amazon.awssdk.core.client.handler.SyncClientHandler;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest;
Expand All @@ -73,6 +82,7 @@ public class SyncClientClass extends SyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final boolean useSraAuth;

public SyncClientClass(GeneratorTaskParams taskParams) {
Expand All @@ -82,6 +92,7 @@ public SyncClientClass(GeneratorTaskParams taskParams) {
this.className = poetExtensions.getClientClass(model.getMetadata().getSyncClient());
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
}

Expand Down Expand Up @@ -113,7 +124,10 @@ protected void addFields(TypeSpec.Builder type) {
.addField(SyncClientHandler.class, "clientHandler", PRIVATE, FINAL)
.addField(protocolSpec.protocolFactory(model))
.addField(SdkClientConfiguration.class, "clientConfiguration", PRIVATE, FINAL)
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL);
.addField(serviceClientConfigurationClassName, "serviceClientConfiguration", PRIVATE, FINAL)
.addField(ParameterizedTypeName.get(BiFunction.class, SdkRequest.class,
SdkClientConfiguration.class, SdkClientConfiguration.class),
"clientConfigurationForRequest", PRIVATE, FINAL);
}

@Override
Expand Down Expand Up @@ -176,6 +190,8 @@ private MethodSpec constructor() {
.addStatement("this.clientHandler = new $T(clientConfiguration)", protocolSpec.getClientHandlerClass())
.addStatement("this.clientConfiguration = clientConfiguration")
.addStatement("this.serviceClientConfiguration = serviceClientConfiguration");

builder.addCode(addConfigurationUpdater(configurationUtils.serviceClientConfigurationBuilderClassName()));
FieldSpec protocolFactoryField = protocolSpec.protocolFactory(model);
if (model.getMetadata().isJsonProtocol()) {
builder.addStatement("this.$N = init($T.builder()).build()", protocolFactoryField.name,
Expand Down Expand Up @@ -207,6 +223,26 @@ private MethodSpec constructor() {
return builder.build();
}

static CodeBlock addConfigurationUpdater(TypeName serviceClientConfigurationBuilderClassName) {
CodeBlock.Builder builder = CodeBlock.builder();
builder.add("$T configurationUpdater = ",
ParameterizedTypeName.get(ConfigurationUpdater.class, SdkServiceClientConfiguration.Builder.class));
builder.add("(consumer, configBuilder) -> {\n$>")
.addStatement("$1T.BuilderInternal serviceConfigBuilder = $1T.builder(configBuilder)",
serviceClientConfigurationBuilderClassName)
.addStatement("consumer.accept(serviceConfigBuilder)")
.addStatement("return serviceConfigBuilder.buildSdkClientConfiguration()")
.add("$<};\n");
builder.add("this.clientConfigurationForRequest = (request, config) -> {\n$>");
builder.addStatement("$T plugins = request.overrideConfiguration()\n"
+ ".map(c -> c.registeredPlugins()).orElse(Collections.emptyList())",
ParameterizedTypeName.get(List.class, SdkPlugin.class));
builder.addStatement("return $T.invokePlugins(config, plugins, configurationUpdater)", SdkClientConfigurationUtil.class);

builder.add("$<};\n");
return builder.build();
}

@Override
protected List<MethodSpec> operations() {
return model.getOperations().values().stream()
Expand Down Expand Up @@ -289,6 +325,8 @@ private MethodSpec traditionalMethod(OperationModel opModel) {
method.endControlFlow();
}

method.addStatement("$T clientConfiguration = this.clientConfigurationForRequest.apply($L, this.clientConfiguration)",
SdkClientConfiguration.class, opModel.getInput().getVariableName());
method.addStatement("$T<$T> metricPublishers = "
+ "resolveMetricPublishers(clientConfiguration, $N.overrideConfiguration().orElse(null))",
List.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)\n", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
Expand Down Expand Up @@ -259,6 +260,7 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(hasInitialRequestEvent(opModel, isRestJson))
.add(".withResponseHandler($L)\n", responseHandlerName(opModel, isRestJson))
.add(".withErrorResponseHandler(errorResponseHandler)\n")
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
.add(credentialType(opModel, intermediateModel))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
Expand Down Expand Up @@ -156,6 +157,7 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(".withResponseHandler(responseHandler)\n")
.add(".withErrorResponseHandler(errorResponseHandler)\n")
.add(credentialType(opModel, intermediateModel))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
hostPrefixExpression(opModel) +
discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));
Expand Down Expand Up @@ -198,11 +199,11 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
if (opModel.hasEventStreamOutput()) {
executionResponseTransformerName = "restAsyncResponseTransformer";
}

builder.add("\n\n$T<$T> executeFuture = clientHandler.execute(new $T<$T, $T>()\n",
CompletableFuture.class, executeFutureValueType,
ClientExecutionParams.class, requestType, pojoResponseType)
.add(".withOperationName(\"$N\")\n", opModel.getOperationName())
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMarshaller($L)\n", asyncMarshaller(intermediateModel, opModel, marshaller, "protocolFactory"));

if (opModel.hasEventStreamOutput()) {
Expand Down
Loading

0 comments on commit b4d643a

Please sign in to comment.