Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add x-amzn-query-mode to http header when service has @awsQueryCompatible trait #5707

Merged
merged 8 commits into from
Nov 14, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ public final ProtocolMarshaller<SdkHttpFullRequest> createProtocolMarshaller(Ope
.operationInfo(operationInfo)
.sendExplicitNullForPayload(false)
.protocolMetadata(protocolMetadata)
.hasAwsQueryCompatible(hasAwsQueryCompatible)
.build();
}
Fred1155 marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ public class JsonProtocolMarshaller implements ProtocolMarshaller<SdkHttpFullReq
private final JsonMarshallerContext marshallerContext;
private final boolean hasEventStreamingInput;
private final boolean hasEvent;
private final boolean hasAwsQueryCompatible;

JsonProtocolMarshaller(URI endpoint,
StructuredJsonGenerator jsonGenerator,
String contentType,
OperationInfo operationInfo,
AwsJsonProtocolMetadata protocolMetadata) {
AwsJsonProtocolMetadata protocolMetadata,
boolean hasAwsQueryCompatible) {
this.endpoint = endpoint;
this.jsonGenerator = jsonGenerator;
this.contentType = contentType;
Expand All @@ -88,6 +90,7 @@ public class JsonProtocolMarshaller implements ProtocolMarshaller<SdkHttpFullReq
this.hasEventStreamingInput = operationInfo.hasEventStreamingInput();
this.hasEvent = operationInfo.hasEvent();
this.request = fillBasicRequestParams(operationInfo);
this.hasAwsQueryCompatible = hasAwsQueryCompatible;
this.marshallerContext = JsonMarshallerContext.builder()
.jsonGenerator(jsonGenerator)
.marshallerRegistry(MARSHALLER_REGISTRY)
Expand Down Expand Up @@ -292,6 +295,10 @@ private SdkHttpFullRequest finishMarshalling() {
}
}

if (hasAwsQueryCompatible) {
request.putHeader("x-amzn-query-mode", "true");
}

return request.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public final class JsonProtocolMarshallerBuilder {
private OperationInfo operationInfo;
private boolean sendExplicitNullForPayload;
private AwsJsonProtocolMetadata protocolMetadata;
private boolean hasAwsQueryCompatible = false;

private JsonProtocolMarshallerBuilder() {
}
Expand Down Expand Up @@ -102,6 +103,14 @@ public JsonProtocolMarshallerBuilder protocolMetadata(AwsJsonProtocolMetadata pr
return this;
}

/**
* @param hasAwsQueryCompatible True if the service is AWS Query compatible, (has the @awsQueryCompatible trait)
*/
public JsonProtocolMarshallerBuilder hasAwsQueryCompatible(boolean hasAwsQueryCompatible) {
this.hasAwsQueryCompatible = hasAwsQueryCompatible;
return this;
}

/**
* @return New instance of {@link ProtocolMarshaller}. If {@link #sendExplicitNullForPayload} is true then the marshaller
* will be wrapped with {@link NullAsEmptyBodyProtocolRequestMarshaller}.
Expand All @@ -111,7 +120,8 @@ public ProtocolMarshaller<SdkHttpFullRequest> build() {
jsonGenerator,
contentType,
operationInfo,
protocolMetadata);
protocolMetadata,
hasAwsQueryCompatible);
return sendExplicitNullForPayload ? protocolMarshaller
: new NullAsEmptyBodyProtocolRequestMarshaller(protocolMarshaller);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.protocols.json;

import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.ClientEndpointProvider;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.protocols.core.OperationInfo;
import software.amazon.awssdk.protocols.core.ProtocolMarshaller;
import static org.assertj.core.api.Assertions.assertThat;

public class AWSQueryModeTest {

private static final OperationInfo EMPTY_OPERATION_INFO = OperationInfo.builder()
.httpMethod(SdkHttpMethod.POST)
.hasImplicitPayloadMembers(true)
.build();

/*
* A simple test POJO to marshall
*/
private static final class TestPojo implements SdkPojo {

private TestPojo() {}

@Override
public List<SdkField<?>> sdkFields() {
return Collections.emptyList();
}

@Override
public boolean equalsBySdkFields(Object other) {
if (!(other instanceof TestPojo)) {
return false;
}
return true;
}

@Override
public Map<String, SdkField<?>> sdkFieldNameToField() {
return Collections.emptyMap();
}

}

@Test
public void testMarshallWithAwsQueryCompatibleTrue() {
SdkClientConfiguration clientConfig =
SdkClientConfiguration.builder()
.option(SdkClientOption.CLIENT_ENDPOINT_PROVIDER,
ClientEndpointProvider.forEndpointOverride(URI.create("http://localhost")))
.build();
AwsJsonProtocolFactory factory =
AwsJsonProtocolFactory.builder()
.clientConfiguration(clientConfig)
.protocolVersion("1.1")
.protocol(AwsJsonProtocol.AWS_JSON)
.hasAwsQueryCompatible(true)
.build();

ProtocolMarshaller<SdkHttpFullRequest> marshaller = factory.createProtocolMarshaller(EMPTY_OPERATION_INFO);
SdkPojo testPojo = new TestPojo();

SdkHttpFullRequest result = marshaller.marshall(testPojo);

assertThat(result.headers()).containsKey("x-amzn-query-mode");
assertThat(result.headers().get("x-amzn-query-mode").get(0)).isEqualTo("true");
}

@Test
public void testMarshallWithNoAwsQueryCompatible() {
SdkClientConfiguration clientConfig =
SdkClientConfiguration.builder()
.option(SdkClientOption.CLIENT_ENDPOINT_PROVIDER,
ClientEndpointProvider.forEndpointOverride(URI.create("http://localhost")))
.build();
AwsJsonProtocolFactory factory =
AwsJsonProtocolFactory.builder()
.clientConfiguration(clientConfig)
.protocolVersion("1.1")
.protocol(AwsJsonProtocol.AWS_JSON)
.build();

ProtocolMarshaller<SdkHttpFullRequest> marshaller = factory.createProtocolMarshaller(EMPTY_OPERATION_INFO);
SdkPojo testPojo = new TestPojo();

SdkHttpFullRequest result = marshaller.marshall(testPojo);

assertThat(result.headers()).doesNotContainKey("x-amzn-query-mode");
}

}
Loading