Skip to content

Commit

Permalink
Support non-Json String payloads (#4450)
Browse files Browse the repository at this point in the history
* Support String payloads for Json protocol

* Changelog

* Fix Spotbugs error

* Refactoring

* Remove unused import
  • Loading branch information
davidh44 authored Sep 21, 2023
1 parent 3ca853c commit 9eb921e
Show file tree
Hide file tree
Showing 13 changed files with 482 additions and 16 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-36670dc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"category": "AWS SDK for Java v2",
"contributor": "",
"type": "feature",
"description": "Adds support for non-Json String payloads"
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ private static boolean isBlobShape(Shape shape) {
return shape != null && "blob".equals(shape.getType());
}

/**
* @return True if shape is a String type. False otherwise
*/
private static boolean isStringShape(Shape shape) {
return shape != null && "String".equals(shape.getType());
}

/**
* If there is a member in the output shape that is explicitly marked as the payload (with the
* payload trait) this method returns the target shape of that member. Otherwise this method
Expand Down Expand Up @@ -192,6 +199,9 @@ public Map<String, OperationModel> constructOperations() {
if (isBlobShape(getPayloadShape(c2jShapes, outputShape))) {
operationModel.setHasBlobMemberAsPayload(true);
}
if (isStringShape(getPayloadShape(c2jShapes, outputShape))) {
operationModel.setHasStringMemberAsPayload(true);
}
}

if (op.getErrors() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public class OperationModel extends DocumentationModel {

private boolean hasBlobMemberAsPayload;

private boolean hasStringMemberAsPayload;

private boolean isAuthenticated = true;

private AuthType authType;
Expand Down Expand Up @@ -211,6 +213,14 @@ public void setHasBlobMemberAsPayload(boolean hasBlobMemberAsPayload) {
this.hasBlobMemberAsPayload = hasBlobMemberAsPayload;
}

public boolean getHasStringMemberAsPayload() {
return this.hasStringMemberAsPayload;
}

public void setHasStringMemberAsPayload(boolean hasStringMemberAsPayload) {
this.hasStringMemberAsPayload = hasStringMemberAsPayload;
}

public boolean hasStreamingInput() {
return inputShape != null && inputShape.isHasStreamingMember();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public CodeBlock responseHandler(IntermediateModel model, OperationModel opModel
CodeBlock.builder()
.add("$T operationMetadata = $T.builder()\n", JsonOperationMetadata.class, JsonOperationMetadata.class)
.add(".hasStreamingSuccessResponse($L)\n", opModel.hasStreamingOutput())
.add(".isPayloadJson($L)\n", !opModel.getHasBlobMemberAsPayload())
.add(".isPayloadJson($L)\n", !opModel.getHasBlobMemberAsPayload() && !opModel.getHasStringMemberAsPayload())
.add(".build();");

if (opModel.hasEventStreamOutput()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.io.ByteArrayInputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.EnumMap;
Expand Down Expand Up @@ -182,6 +183,11 @@ void doMarshall(SdkPojo pojo) {
if (val != null) {
request.contentStreamProvider(((SdkBytes) val)::asInputStream);
}
} else if (isExplicitStringPayload(field)) {
if (val != null) {
byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8);
request.contentStreamProvider(() -> new ByteArrayInputStream(content));
}
} else if (isExplicitPayloadMember(field)) {
marshallExplicitJsonPayload(field, val);
} else {
Expand All @@ -194,6 +200,10 @@ private boolean isExplicitBinaryPayload(SdkField<?> field) {
return isExplicitPayloadMember(field) && MarshallingType.SDK_BYTES.equals(field.marshallingType());
}

private boolean isExplicitStringPayload(SdkField<?> field) {
return isExplicitPayloadMember(field) && MarshallingType.STRING.equals(field.marshallingType());
}

private boolean isExplicitPayloadMember(SdkField<?> field) {
return field.containsTrait(PayloadTrait.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,10 @@ public T unmarshall(JsonUnmarshallerContext context,

public <TypeT extends SdkPojo> TypeT unmarshall(SdkPojo sdkPojo,
SdkHttpFullResponse response) throws IOException {
if (hasPayloadMembersOnUnmarshall(sdkPojo) && !hasExplicitBlobPayloadMember(sdkPojo) && response.content().isPresent()) {
if (hasPayloadMembersOnUnmarshall(sdkPojo)
&& !hasExplicitBlobPayloadMember(sdkPojo)
&& !hasExplicitStringPayloadMember(sdkPojo)
&& response.content().isPresent()) {
JsonNode jsonNode = parser.parse(response.content().get());
return unmarshall(sdkPojo, response, jsonNode);
} else {
Expand All @@ -201,6 +204,12 @@ private boolean hasExplicitBlobPayloadMember(SdkPojo sdkPojo) {
.anyMatch(f -> isExplicitPayloadMember(f) && f.marshallingType() == MarshallingType.SDK_BYTES);
}

private boolean hasExplicitStringPayloadMember(SdkPojo sdkPojo) {
return sdkPojo.sdkFields()
.stream()
.anyMatch(f -> isExplicitPayloadMember(f) && f.marshallingType() == MarshallingType.STRING);
}

private static boolean isExplicitPayloadMember(SdkField<?> f) {
return f.containsTrait(PayloadTrait.class);
}
Expand Down Expand Up @@ -234,6 +243,13 @@ private static <TypeT extends SdkPojo> TypeT unmarshallStructured(SdkPojo sdkPoj
} else {
field.set(sdkPojo, SdkBytes.fromByteArrayUnsafe(new byte[0]));
}
} else if (isExplicitPayloadMember(field) && field.marshallingType() == MarshallingType.STRING) {
Optional<AbortableInputStream> responseContent = context.response().content();
if (responseContent.isPresent()) {
field.set(sdkPojo, SdkBytes.fromInputStream(responseContent.get()).asUtf8String());
} else {
field.set(sdkPojo, "");
}
} else {
JsonNode jsonFieldContent = getJsonNode(jsonContent, field);
JsonUnmarshaller<Object> unmarshaller = context.getUnmarshaller(field.location(), field.marshallingType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@
"input":{"shape":"OperationWithExplicitPayloadBlobInput"},
"output":{"shape":"OperationWithExplicitPayloadBlobInput"}
},
"OperationWithExplicitPayloadString":{
"name":"OperationWithExplicitPayloadString",
"http":{
"method":"POST",
"requestUri":"/2016-03-11/operationWithExplicitPayloadString"
},
"input":{"shape":"OperationWithExplicitPayloadStringInput"},
"output":{"shape":"OperationWithExplicitPayloadStringInput"}
},
"OperationWithExplicitPayloadStructure":{
"name":"OperationWithExplicitPayloadStructure",
"http":{
Expand Down Expand Up @@ -697,6 +706,13 @@
},
"payload":"PayloadMember"
},
"OperationWithExplicitPayloadStringInput":{
"type":"structure",
"members":{
"PayloadMember":{"shape":"String"}
},
"payload":"PayloadMember"
},
"OperationWithExplicitPayloadStructureInput":{
"type":"structure",
"members":{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.protocolrestjson;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
import static com.github.tomakehurst.wiremock.client.WireMock.post;
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
import static org.assertj.core.api.Assertions.assertThat;

import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;

@WireMockTest
public class StringPayloadUnmarshallingTest {
private static final String TEST_PAYLOAD = "X";

private static List<Arguments> testParameters() {
List<Arguments> testCases = new ArrayList<>();
for (ClientType clientType : ClientType.values()) {
for (Protocol protocol : Protocol.values()) {
for (StringLocation value : StringLocation.values()) {
for (ContentLength contentLength : ContentLength.values()) {
testCases.add(Arguments.arguments(clientType, protocol, value, contentLength));
}
}
}
}
return testCases;
}

private enum ClientType {
SYNC,
ASYNC
}

private enum Protocol {
JSON
// TODO - add support for XML
}

private enum StringLocation {
PAYLOAD,
FIELD
}

private enum ContentLength {
ZERO,
NOT_PRESENT
}

@ParameterizedTest
@MethodSource("testParameters")
public void missingStringPayload_unmarshalledCorrectly(ClientType clientType,
Protocol protocol,
StringLocation stringLoc,
ContentLength contentLength,
WireMockRuntimeInfo wm) {
if (contentLength == ContentLength.ZERO) {
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("Content-Length", "0").withBody("")));
} else if (contentLength == ContentLength.NOT_PRESENT) {
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody("")));
}

String serviceResult = callService(wm, clientType, protocol, stringLoc);

if (stringLoc == StringLocation.PAYLOAD) {
assertThat(serviceResult).isNotNull().isEqualTo("");
} else if (stringLoc == StringLocation.FIELD) {
assertThat(serviceResult).isNull();
}
}

@ParameterizedTest
@MethodSource("testParameters")
public void presentStringPayload_unmarshalledCorrectly(ClientType clientType,
Protocol protocol,
StringLocation stringLoc,
ContentLength contentLength,
WireMockRuntimeInfo wm) {
String responsePayload = presentStringResponse(protocol, stringLoc);

if (contentLength == ContentLength.ZERO) {
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200)
.withHeader("Content-Length", Integer.toString(responsePayload.length()))
.withBody(responsePayload)));
} else if (contentLength == ContentLength.NOT_PRESENT) {
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(responsePayload)));
}

assertThat(callService(wm, clientType, protocol, stringLoc)).isEqualTo(TEST_PAYLOAD);
}

private String presentStringResponse(Protocol protocol, StringLocation stringLoc) {
switch (stringLoc) {
case PAYLOAD: return TEST_PAYLOAD;
case FIELD:
switch (protocol) {
case JSON: return "{\"StringMember\": \"X\"}";
// TODO - add support for XML
default: throw new UnsupportedOperationException();
}
default: throw new UnsupportedOperationException();
}

}

private String callService(WireMockRuntimeInfo wm, ClientType clientType, Protocol protocol, StringLocation stringLoc) {
switch (clientType) {
case SYNC: return syncCallService(wm, protocol, stringLoc);
case ASYNC: return asyncCallService(wm, protocol, stringLoc);
default: throw new UnsupportedOperationException();
}
}

private String syncCallService(WireMockRuntimeInfo wm, Protocol protocol, StringLocation stringLoc) {
switch (protocol) {
case JSON: return syncJsonCallService(wm, stringLoc);
// TODO - add support for XML
default: throw new UnsupportedOperationException();
}
}

private String asyncCallService(WireMockRuntimeInfo wm, Protocol protocol, StringLocation stringLoc) {
switch (protocol) {
case JSON: return asyncJsonCallService(wm, stringLoc);
// TODO - add support for XML
default: throw new UnsupportedOperationException();
}
}

private String syncJsonCallService(WireMockRuntimeInfo wm, StringLocation stringLoc) {
ProtocolRestJsonClient client =
ProtocolRestJsonClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
.region(Region.US_EAST_1)
.endpointOverride(URI.create(wm.getHttpBaseUrl()))
.build();
switch (stringLoc) {
case PAYLOAD: return client.operationWithExplicitPayloadString(r -> {}).payloadMember();
case FIELD: return client.allTypes(r -> {}).stringMember();
default: throw new UnsupportedOperationException();
}
}

private String asyncJsonCallService(WireMockRuntimeInfo wm, StringLocation stringLoc) {
ProtocolRestJsonAsyncClient client =
ProtocolRestJsonAsyncClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
.region(Region.US_EAST_1)
.endpointOverride(URI.create(wm.getHttpBaseUrl()))
.build();

switch (stringLoc) {
case PAYLOAD: return client.operationWithExplicitPayloadString(r -> {}).join().payloadMember();
case FIELD: return client.allTypes(r -> {}).join().stringMember();
default: throw new UnsupportedOperationException();
}
}
}
Loading

0 comments on commit 9eb921e

Please sign in to comment.