Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request level plugin execution #4518

Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils;
import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkServiceClientConfiguration;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.client.config.internal.ConfigurationUpdater;
import software.amazon.awssdk.core.client.config.internal.SdkClientConfigurationUtil;
import software.amazon.awssdk.core.client.handler.AsyncClientHandler;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache;
import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest;
Expand All @@ -96,6 +102,7 @@ public final class AsyncClientClass extends AsyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final boolean useSraAuth;

public AsyncClientClass(GeneratorTaskParams dependencies) {
Expand All @@ -106,6 +113,7 @@ public AsyncClientClass(GeneratorTaskParams dependencies) {
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
}

@Override
Expand Down Expand Up @@ -165,7 +173,7 @@ protected void addAdditionalMethods(TypeSpec.Builder type) {
type.addMethod(isSignerOverriddenOnClientMethod());
}
}

type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName()));
protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod);
}

Expand Down Expand Up @@ -278,6 +286,28 @@ protected MethodSpec serviceClientConfigMethod() {
.build();
}

protected static MethodSpec updateSdkClientConfigurationMethod(TypeName serviceClientConfigurationBuilderClassName) {
MethodSpec.Builder builder = MethodSpec.methodBuilder("updateSdkClientConfiguration")
.addModifiers(PROTECTED)
.addParameter(SdkRequest.class, "request")
.addParameter(SdkClientConfiguration.class, "clientConfiguration")
.returns(SdkClientConfiguration.class);
builder.addCode("$T configurationUpdater = ",
ParameterizedTypeName.get(ConfigurationUpdater.class, SdkServiceClientConfiguration.Builder.class));
builder.addCode("(consumer, configBuilder) -> {\n$>")
.addStatement("$1T.BuilderInternal serviceConfigBuilder = $1T.builder(configBuilder)",
serviceClientConfigurationBuilderClassName)
.addStatement("consumer.accept(serviceConfigBuilder)")
.addStatement("return serviceConfigBuilder.buildSdkClientConfiguration()")
.addCode("$<};\n");
builder.addStatement("$T plugins = request.overrideConfiguration()\n"
+ ".map(c -> c.plugins()).orElse(Collections.emptyList())",
ParameterizedTypeName.get(List.class, SdkPlugin.class));
builder.addStatement("return $T.invokePlugins(clientConfiguration, plugins, configurationUpdater)",
SdkClientConfigurationUtil.class);
return builder.build();
}

@Override
protected void addCloseMethod(TypeSpec.Builder type) {
MethodSpec method = MethodSpec.methodBuilder("close")
Expand All @@ -294,7 +324,8 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation

builder.addModifiers(PUBLIC)
.addAnnotation(Override.class);

builder.addStatement("$T clientConfiguration = updateSdkClientConfiguration($L, this.clientConfiguration)",
SdkClientConfiguration.class, opModel.getInput().getVariableName());
builder.addStatement("$T<$T> metricPublishers = "
+ "resolveMetricPublishers(clientConfiguration, $N.overrideConfiguration().orElse(null))",
List.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static javax.lang.model.element.Modifier.PROTECTED;
import static javax.lang.model.element.Modifier.PUBLIC;
import static javax.lang.model.element.Modifier.STATIC;
import static software.amazon.awssdk.codegen.poet.client.AsyncClientClass.updateSdkClientConfigurationMethod;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.addS3ArnableFieldCode;
import static software.amazon.awssdk.codegen.poet.client.ClientClassUtils.applySignerOverrideMethod;

Expand Down Expand Up @@ -52,6 +53,7 @@
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
Expand All @@ -73,6 +75,7 @@ public class SyncClientClass extends SyncClientInterface {
private final ClassName className;
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final boolean useSraAuth;

public SyncClientClass(GeneratorTaskParams taskParams) {
Expand All @@ -82,6 +85,7 @@ public SyncClientClass(GeneratorTaskParams taskParams) {
this.className = poetExtensions.getClientClass(model.getMetadata().getSyncClient());
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
}

Expand Down Expand Up @@ -133,7 +137,7 @@ protected void addAdditionalMethods(TypeSpec.Builder type) {
.addMethod(resolveMetricPublishersMethod());

protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod);

type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName()));
type.addMethod(protocolSpec.initProtocolFactory(model));
}

Expand Down Expand Up @@ -176,6 +180,7 @@ private MethodSpec constructor() {
.addStatement("this.clientHandler = new $T(clientConfiguration)", protocolSpec.getClientHandlerClass())
.addStatement("this.clientConfiguration = clientConfiguration")
.addStatement("this.serviceClientConfiguration = serviceClientConfiguration");

FieldSpec protocolFactoryField = protocolSpec.protocolFactory(model);
if (model.getMetadata().isJsonProtocol()) {
builder.addStatement("this.$N = init($T.builder()).build()", protocolFactoryField.name,
Expand Down Expand Up @@ -289,6 +294,8 @@ private MethodSpec traditionalMethod(OperationModel opModel) {
method.endControlFlow();
}

method.addStatement("$T clientConfiguration = updateSdkClientConfiguration($L, this.clientConfiguration)",
SdkClientConfiguration.class, opModel.getInput().getVariableName());
method.addStatement("$T<$T> metricPublishers = "
+ "resolveMetricPublishers(clientConfiguration, $N.overrideConfiguration().orElse(null))",
List.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ protected void addAdditionalMethods(TypeSpec.Builder type) {
.addMethod(serviceMetadata());

PoetUtils.addJavadoc(type::addJavadoc, getJavadoc());

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)\n", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
Expand Down Expand Up @@ -259,6 +260,7 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(hasInitialRequestEvent(opModel, isRestJson))
.add(".withResponseHandler($L)\n", responseHandlerName(opModel, isRestJson))
.add(".withErrorResponseHandler(errorResponseHandler)\n")
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
.add(credentialType(opModel, intermediateModel))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
Expand Down Expand Up @@ -156,6 +157,7 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(".withResponseHandler(responseHandler)\n")
.add(".withErrorResponseHandler(errorResponseHandler)\n")
.add(credentialType(opModel, intermediateModel))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public CodeBlock executionHandler(OperationModel opModel) {
hostPrefixExpression(opModel) +
discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));
Expand Down Expand Up @@ -198,11 +199,11 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
if (opModel.hasEventStreamOutput()) {
executionResponseTransformerName = "restAsyncResponseTransformer";
}

builder.add("\n\n$T<$T> executeFuture = clientHandler.execute(new $T<$T, $T>()\n",
CompletableFuture.class, executeFutureValueType,
ClientExecutionParams.class, requestType, pojoResponseType)
.add(".withOperationName(\"$N\")\n", opModel.getOperationName())
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMarshaller($L)\n", asyncMarshaller(intermediateModel, opModel, marshaller, "protocolFactory"));

if (opModel.hasEventStreamOutput()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import software.amazon.awssdk.codegen.model.intermediate.ShapeType;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.util.DefaultSdkAutoConstructList;
import software.amazon.awssdk.core.util.DefaultSdkAutoConstructMap;
Expand Down Expand Up @@ -113,14 +112,6 @@ public TypeSpec builderInterface() {
"builderConsumer")
.addModifiers(PUBLIC, Modifier.ABSTRACT)
.build());

builder.addMethod(MethodSpec.methodBuilder("addPlugin")
.addAnnotation(Override.class)
.returns(builderInterfaceName())
.addParameter(SdkPlugin.class , "plugin")
.addModifiers(PUBLIC, Modifier.ABSTRACT)
.build());

}

return builder.build();
Expand Down Expand Up @@ -279,15 +270,6 @@ private List<MethodSpec> accessors() {
.addStatement("super.overrideConfiguration(builderConsumer)")
.addStatement("return this")
.build());

accessors.add(MethodSpec.methodBuilder("addPlugin")
.addAnnotation(Override.class)
.returns(builderInterfaceName())
.addParameter(SdkPlugin.class, "plugin")
.addModifiers(PUBLIC)
.addStatement("super.addPlugin(plugin)")
.addStatement("return this")
.build());
}

return accessors;
Expand Down
Loading
Loading