diff --git a/.changes/next-release/feature-AWSSDKforJavav2-36670dc.json b/.changes/next-release/feature-AWSSDKforJavav2-36670dc.json new file mode 100644 index 000000000000..eab097e3f4fd --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-36670dc.json @@ -0,0 +1,6 @@ +{ + "category": "AWS SDK for Java v2", + "contributor": "", + "type": "feature", + "description": "Adds support for non-Json String payloads" +} diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/AddOperations.java b/codegen/src/main/java/software/amazon/awssdk/codegen/AddOperations.java index 79bb81470f5e..d264a2035f7e 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/AddOperations.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/AddOperations.java @@ -69,6 +69,13 @@ private static boolean isBlobShape(Shape shape) { return shape != null && "blob".equals(shape.getType()); } + /** + * @return True if shape is a String type. False otherwise + */ + private static boolean isStringShape(Shape shape) { + return shape != null && "String".equals(shape.getType()); + } + /** * If there is a member in the output shape that is explicitly marked as the payload (with the * payload trait) this method returns the target shape of that member. Otherwise this method @@ -192,6 +199,9 @@ public Map constructOperations() { if (isBlobShape(getPayloadShape(c2jShapes, outputShape))) { operationModel.setHasBlobMemberAsPayload(true); } + if (isStringShape(getPayloadShape(c2jShapes, outputShape))) { + operationModel.setHasStringMemberAsPayload(true); + } } if (op.getErrors() != null) { diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/model/intermediate/OperationModel.java b/codegen/src/main/java/software/amazon/awssdk/codegen/model/intermediate/OperationModel.java index 1ff197191126..0510abf6e3be 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/model/intermediate/OperationModel.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/model/intermediate/OperationModel.java @@ -48,6 +48,8 @@ public class OperationModel extends DocumentationModel { private boolean hasBlobMemberAsPayload; + private boolean hasStringMemberAsPayload; + private boolean isAuthenticated = true; private AuthType authType; @@ -211,6 +213,14 @@ public void setHasBlobMemberAsPayload(boolean hasBlobMemberAsPayload) { this.hasBlobMemberAsPayload = hasBlobMemberAsPayload; } + public boolean getHasStringMemberAsPayload() { + return this.hasStringMemberAsPayload; + } + + public void setHasStringMemberAsPayload(boolean hasStringMemberAsPayload) { + this.hasStringMemberAsPayload = hasStringMemberAsPayload; + } + public boolean hasStreamingInput() { return inputShape != null && inputShape.isHasStreamingMember(); } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index 44922d4e2b32..ce46ecf7c824 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -142,7 +142,7 @@ public CodeBlock responseHandler(IntermediateModel model, OperationModel opModel CodeBlock.builder() .add("$T operationMetadata = $T.builder()\n", JsonOperationMetadata.class, JsonOperationMetadata.class) .add(".hasStreamingSuccessResponse($L)\n", opModel.hasStreamingOutput()) - .add(".isPayloadJson($L)\n", !opModel.getHasBlobMemberAsPayload()) + .add(".isPayloadJson($L)\n", !opModel.getHasBlobMemberAsPayload() && !opModel.getHasStringMemberAsPayload()) .add(".build();"); if (opModel.hasEventStreamOutput()) { diff --git a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java index 15b0c910d472..b81110bd1026 100644 --- a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java +++ b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java @@ -23,6 +23,7 @@ import java.io.ByteArrayInputStream; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; import java.util.EnumMap; @@ -182,6 +183,11 @@ void doMarshall(SdkPojo pojo) { if (val != null) { request.contentStreamProvider(((SdkBytes) val)::asInputStream); } + } else if (isExplicitStringPayload(field)) { + if (val != null) { + byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8); + request.contentStreamProvider(() -> new ByteArrayInputStream(content)); + } } else if (isExplicitPayloadMember(field)) { marshallExplicitJsonPayload(field, val); } else { @@ -194,6 +200,10 @@ private boolean isExplicitBinaryPayload(SdkField field) { return isExplicitPayloadMember(field) && MarshallingType.SDK_BYTES.equals(field.marshallingType()); } + private boolean isExplicitStringPayload(SdkField field) { + return isExplicitPayloadMember(field) && MarshallingType.STRING.equals(field.marshallingType()); + } + private boolean isExplicitPayloadMember(SdkField field) { return field.containsTrait(PayloadTrait.class); } diff --git a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java index 3ed0c9fb8cae..fe8bdc0f7015 100644 --- a/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java +++ b/core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/unmarshall/JsonProtocolUnmarshaller.java @@ -187,7 +187,10 @@ public T unmarshall(JsonUnmarshallerContext context, public TypeT unmarshall(SdkPojo sdkPojo, SdkHttpFullResponse response) throws IOException { - if (hasPayloadMembersOnUnmarshall(sdkPojo) && !hasExplicitBlobPayloadMember(sdkPojo) && response.content().isPresent()) { + if (hasPayloadMembersOnUnmarshall(sdkPojo) + && !hasExplicitBlobPayloadMember(sdkPojo) + && !hasExplicitStringPayloadMember(sdkPojo) + && response.content().isPresent()) { JsonNode jsonNode = parser.parse(response.content().get()); return unmarshall(sdkPojo, response, jsonNode); } else { @@ -201,6 +204,12 @@ private boolean hasExplicitBlobPayloadMember(SdkPojo sdkPojo) { .anyMatch(f -> isExplicitPayloadMember(f) && f.marshallingType() == MarshallingType.SDK_BYTES); } + private boolean hasExplicitStringPayloadMember(SdkPojo sdkPojo) { + return sdkPojo.sdkFields() + .stream() + .anyMatch(f -> isExplicitPayloadMember(f) && f.marshallingType() == MarshallingType.STRING); + } + private static boolean isExplicitPayloadMember(SdkField f) { return f.containsTrait(PayloadTrait.class); } @@ -234,6 +243,13 @@ private static TypeT unmarshallStructured(SdkPojo sdkPoj } else { field.set(sdkPojo, SdkBytes.fromByteArrayUnsafe(new byte[0])); } + } else if (isExplicitPayloadMember(field) && field.marshallingType() == MarshallingType.STRING) { + Optional responseContent = context.response().content(); + if (responseContent.isPresent()) { + field.set(sdkPojo, SdkBytes.fromInputStream(responseContent.get()).asUtf8String()); + } else { + field.set(sdkPojo, ""); + } } else { JsonNode jsonFieldContent = getJsonNode(jsonContent, field); JsonUnmarshaller unmarshaller = context.getUnmarshaller(field.location(), field.marshallingType()); diff --git a/test/codegen-generated-classes-test/src/main/resources/codegen-resources/customresponsemetadata/service-2.json b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/customresponsemetadata/service-2.json index 8cdb71614e38..11d12c84f6ce 100644 --- a/test/codegen-generated-classes-test/src/main/resources/codegen-resources/customresponsemetadata/service-2.json +++ b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/customresponsemetadata/service-2.json @@ -111,6 +111,15 @@ "input":{"shape":"OperationWithExplicitPayloadBlobInput"}, "output":{"shape":"OperationWithExplicitPayloadBlobInput"} }, + "OperationWithExplicitPayloadString":{ + "name":"OperationWithExplicitPayloadString", + "http":{ + "method":"POST", + "requestUri":"/2016-03-11/operationWithExplicitPayloadString" + }, + "input":{"shape":"OperationWithExplicitPayloadStringInput"}, + "output":{"shape":"OperationWithExplicitPayloadStringInput"} + }, "OperationWithExplicitPayloadStructure":{ "name":"OperationWithExplicitPayloadStructure", "http":{ @@ -697,6 +706,13 @@ }, "payload":"PayloadMember" }, + "OperationWithExplicitPayloadStringInput":{ + "type":"structure", + "members":{ + "PayloadMember":{"shape":"String"} + }, + "payload":"PayloadMember" + }, "OperationWithExplicitPayloadStructureInput":{ "type":"structure", "members":{ diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/protocolrestjson/StringPayloadUnmarshallingTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/protocolrestjson/StringPayloadUnmarshallingTest.java new file mode 100644 index 000000000000..27590c5a20cf --- /dev/null +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/protocolrestjson/StringPayloadUnmarshallingTest.java @@ -0,0 +1,182 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.protocolrestjson; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static org.assertj.core.api.Assertions.assertThat; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +@WireMockTest +public class StringPayloadUnmarshallingTest { + private static final String TEST_PAYLOAD = "X"; + + private static List testParameters() { + List testCases = new ArrayList<>(); + for (ClientType clientType : ClientType.values()) { + for (Protocol protocol : Protocol.values()) { + for (StringLocation value : StringLocation.values()) { + for (ContentLength contentLength : ContentLength.values()) { + testCases.add(Arguments.arguments(clientType, protocol, value, contentLength)); + } + } + } + } + return testCases; + } + + private enum ClientType { + SYNC, + ASYNC + } + + private enum Protocol { + JSON + // TODO - add support for XML + } + + private enum StringLocation { + PAYLOAD, + FIELD + } + + private enum ContentLength { + ZERO, + NOT_PRESENT + } + + @ParameterizedTest + @MethodSource("testParameters") + public void missingStringPayload_unmarshalledCorrectly(ClientType clientType, + Protocol protocol, + StringLocation stringLoc, + ContentLength contentLength, + WireMockRuntimeInfo wm) { + if (contentLength == ContentLength.ZERO) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withHeader("Content-Length", "0").withBody(""))); + } else if (contentLength == ContentLength.NOT_PRESENT) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(""))); + } + + String serviceResult = callService(wm, clientType, protocol, stringLoc); + + if (stringLoc == StringLocation.PAYLOAD) { + assertThat(serviceResult).isNotNull().isEqualTo(""); + } else if (stringLoc == StringLocation.FIELD) { + assertThat(serviceResult).isNull(); + } + } + + @ParameterizedTest + @MethodSource("testParameters") + public void presentStringPayload_unmarshalledCorrectly(ClientType clientType, + Protocol protocol, + StringLocation stringLoc, + ContentLength contentLength, + WireMockRuntimeInfo wm) { + String responsePayload = presentStringResponse(protocol, stringLoc); + + if (contentLength == ContentLength.ZERO) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200) + .withHeader("Content-Length", Integer.toString(responsePayload.length())) + .withBody(responsePayload))); + } else if (contentLength == ContentLength.NOT_PRESENT) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(responsePayload))); + } + + assertThat(callService(wm, clientType, protocol, stringLoc)).isEqualTo(TEST_PAYLOAD); + } + + private String presentStringResponse(Protocol protocol, StringLocation stringLoc) { + switch (stringLoc) { + case PAYLOAD: return TEST_PAYLOAD; + case FIELD: + switch (protocol) { + case JSON: return "{\"StringMember\": \"X\"}"; + // TODO - add support for XML + default: throw new UnsupportedOperationException(); + } + default: throw new UnsupportedOperationException(); + } + + } + + private String callService(WireMockRuntimeInfo wm, ClientType clientType, Protocol protocol, StringLocation stringLoc) { + switch (clientType) { + case SYNC: return syncCallService(wm, protocol, stringLoc); + case ASYNC: return asyncCallService(wm, protocol, stringLoc); + default: throw new UnsupportedOperationException(); + } + } + + private String syncCallService(WireMockRuntimeInfo wm, Protocol protocol, StringLocation stringLoc) { + switch (protocol) { + case JSON: return syncJsonCallService(wm, stringLoc); + // TODO - add support for XML + default: throw new UnsupportedOperationException(); + } + } + + private String asyncCallService(WireMockRuntimeInfo wm, Protocol protocol, StringLocation stringLoc) { + switch (protocol) { + case JSON: return asyncJsonCallService(wm, stringLoc); + // TODO - add support for XML + default: throw new UnsupportedOperationException(); + } + } + + private String syncJsonCallService(WireMockRuntimeInfo wm, StringLocation stringLoc) { + ProtocolRestJsonClient client = + ProtocolRestJsonClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + switch (stringLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadString(r -> {}).payloadMember(); + case FIELD: return client.allTypes(r -> {}).stringMember(); + default: throw new UnsupportedOperationException(); + } + } + + private String asyncJsonCallService(WireMockRuntimeInfo wm, StringLocation stringLoc) { + ProtocolRestJsonAsyncClient client = + ProtocolRestJsonAsyncClient.builder() + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"))) + .region(Region.US_EAST_1) + .endpointOverride(URI.create(wm.getHttpBaseUrl())) + .build(); + + switch (stringLoc) { + case PAYLOAD: return client.operationWithExplicitPayloadString(r -> {}).join().payloadMember(); + case FIELD: return client.allTypes(r -> {}).join().stringMember(); + default: throw new UnsupportedOperationException(); + } + } +} diff --git a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-contenttype.json b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-contenttype.json index b82b0f70e93b..7260c5743e6b 100644 --- a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-contenttype.json +++ b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-contenttype.json @@ -138,6 +138,33 @@ } } }, +{ + "description": "TestStringPayload", + "given": { + "input": { + "data": "1234", + "contentType": "text/plain" + } + }, + "when": { + "action": "marshall", + "operation": "TestStringPayload" + }, + "then": { + "serializedAs": { + "uri": "/string-payload", + "method": "POST", + "headers": { + "contains": { + "Content-Type": "text/plain" + } + }, + "body": { + "equals": "1234" + } + } + } +}, { "description": "TestBlobPayloadNoParams", "given": { @@ -163,6 +190,31 @@ } } }, +{ + "description": "TestStringPayloadNoParams", + "given": { + "input": { + } + }, + "when": { + "action": "marshall", + "operation": "TestStringPayload" + }, + "then": { + "serializedAs": { + "uri": "/string-payload", + "method": "POST", + "headers": { + "doesNotContain": [ + "Content-Type" + ] + }, + "body": { + "equals": "" + } + } + } +}, { "description": "NoPayload", "given": { diff --git a/test/protocol-tests/src/main/resources/codegen-resources/restjson/contenttype/service-2.json b/test/protocol-tests/src/main/resources/codegen-resources/restjson/contenttype/service-2.json index 717cc54e7d4f..65d3bc977488 100644 --- a/test/protocol-tests/src/main/resources/codegen-resources/restjson/contenttype/service-2.json +++ b/test/protocol-tests/src/main/resources/codegen-resources/restjson/contenttype/service-2.json @@ -36,6 +36,14 @@ }, "input": {"shape": "TestBlobPayloadRequest"} }, + "TestStringPayload": { + "name": "TestStringPayload", + "http": { + "method": "POST", + "requestUri": "/string-payload" + }, + "input": {"shape": "TestStringPayloadRequest"} + }, "NoPayload": { "name": "NoPayload", "http": { @@ -173,6 +181,24 @@ "documentation":"

The request structure for a blob payload request.

", "payload":"data" }, + "TestStringPayloadRequest":{ + "type":"structure", + "required":[], + "members":{ + "data":{ + "shape":"String", + "documentation":"

String payload to post

" + }, + "contentType":{ + "shape":"String", + "documentation":"

Optional content-type header

", + "location":"header", + "locationName":"Content-Type" + } + }, + "documentation":"

The request structure for a String payload request.

", + "payload":"data" + }, "TestEventStreamRequest": { "type": "structure", "required": [ @@ -191,6 +217,9 @@ "BlobAndHeadersEvent": { "shape": "BlobAndHeadersEvent" }, + "StringAndHeadersEvent": { + "shape": "StringAndHeadersEvent" + }, "HeadersOnlyEvent": { "shape": "HeadersOnlyEvent" }, @@ -214,6 +243,20 @@ }, "event": true }, + "StringAndHeadersEvent": { + "type": "structure", + "members": { + "StringPayloadMember": { + "shape":"String", + "eventpayload":true + }, + "HeaderMember": { + "shape": "String", + "eventheader": true + } + }, + "event": true + }, "ImplicitPayloadAndHeadersEvent": { "type": "structure", "members": { diff --git a/test/protocol-tests/src/main/resources/codegen-resources/restjson/service-2.json b/test/protocol-tests/src/main/resources/codegen-resources/restjson/service-2.json index 79cc7f20ce35..42ca767245a5 100644 --- a/test/protocol-tests/src/main/resources/codegen-resources/restjson/service-2.json +++ b/test/protocol-tests/src/main/resources/codegen-resources/restjson/service-2.json @@ -131,6 +131,15 @@ "input":{"shape":"OperationWithExplicitPayloadBlobInput"}, "output":{"shape":"OperationWithExplicitPayloadBlobInput"} }, + "OperationWithExplicitPayloadString":{ + "name":"OperationWithExplicitPayloadString", + "http":{ + "method":"POST", + "requestUri":"/2016-03-11/operationWithExplicitPayloadString" + }, + "input":{"shape":"OperationWithExplicitPayloadStringInput"}, + "output":{"shape":"OperationWithExplicitPayloadStringInput"} + }, "OperationWithExplicitPayloadStructure":{ "name":"OperationWithExplicitPayloadStructure", "http":{ @@ -217,6 +226,19 @@ "shape": "EventStreamOutput" } }, + "EventStreamStringPayloadOperation": { + "name": "EventStreamStringPayloadOperation", + "http": { + "method": "POST", + "requestUri": "/2016-03-11/eventStreamStringPayloadOperation" + }, + "input": { + "shape": "EventStreamStringPayloadOperationRequest" + }, + "output": { + "shape": "EventStreamOutput" + } + }, "DocumentInputOperation":{ "name":"DocumentInputOperation", "http":{ @@ -669,6 +691,13 @@ }, "payload":"PayloadMember" }, + "OperationWithExplicitPayloadStringInput":{ + "type":"structure", + "members":{ + "PayloadMember":{"shape":"String"} + }, + "payload":"PayloadMember" + }, "OperationWithExplicitPayloadStructureInput":{ "type":"structure", "members":{ @@ -815,6 +844,18 @@ }, "payload":"InputEventStream" }, + "EventStreamStringPayloadOperationRequest": { + "type": "structure", + "required": [ + "InputEventStream" + ], + "members": { + "InputEventStreamStringPayload": { + "shape": "InputEventStreamStringPayload" + } + }, + "payload":"InputEventStreamStringPayload" + }, "EventStreamOutput": { "type": "structure", "required": [ @@ -835,6 +876,15 @@ }, "eventstream": true }, + "InputEventStreamStringPayload": { + "type": "structure", + "members": { + "InputEvent": { + "shape": "InputEventStringPayload" + } + }, + "eventstream": true + }, "InputEvent": { "type": "structure", "members": { @@ -849,6 +899,20 @@ }, "event": true }, + "InputEventStringPayload": { + "type": "structure", + "members": { + "ExplicitPayloadStringMember": { + "shape":"String", + "eventpayload":true + }, + "HeaderMember": { + "shape": "String", + "eventheader": true + } + }, + "event": true + }, "ExplicitPayloadMember":{"type":"blob"}, "EventStream": { "type": "structure", diff --git a/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/EventTransformTest.java b/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/EventTransformTest.java index cc7bc36ff4e2..93b89c178064 100644 --- a/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/EventTransformTest.java +++ b/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/EventTransformTest.java @@ -19,7 +19,8 @@ import java.net.URI; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; @@ -33,13 +34,16 @@ import software.amazon.awssdk.protocols.json.AwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.services.protocolrestjson.model.InputEvent; +import software.amazon.awssdk.services.protocolrestjson.model.InputEventStringPayload; import software.amazon.awssdk.services.protocolrestjson.transform.InputEventMarshaller; +import software.amazon.awssdk.services.protocolrestjson.transform.InputEventStringPayloadMarshaller; /** * Marshalling and Unmarshalling tests for events. */ public class EventTransformTest { - private static final String EXPLICIT_PAYLOAD = "{\"ExplicitPayloadMember\": \"bar\"}"; + private static final String EXPLICIT_PAYLOAD_JSON = "{\"ExplicitPayloadMember\": \"bar\"}"; + private static final String EXPLICIT_PAYLOAD_NON_JSON = "bar"; private static final String HEADER_MEMBER_NAME = "HeaderMember"; private static final String HEADER_MEMBER = "foo"; private static AwsJsonProtocolFactory protocolFactory; @@ -55,34 +59,69 @@ public static void setup() { .build(); } - @Test - public void testUnmarshalling() throws Exception { + @ParameterizedTest + @ValueSource(strings = {EXPLICIT_PAYLOAD_JSON, EXPLICIT_PAYLOAD_NON_JSON}) + public void testUnmarshalling_BlobPayload(String payload) throws Exception { HttpResponseHandler responseHandler = protocolFactory - .createResponseHandler(JsonOperationMetadata.builder().build(), InputEvent::builder); + .createResponseHandler(JsonOperationMetadata.builder().build(), InputEvent::builder); InputEvent unmarshalled = (InputEvent) responseHandler.handle(SdkHttpFullResponse.builder() - .content(AbortableInputStream.create(SdkBytes.fromUtf8String(EXPLICIT_PAYLOAD).asInputStream())) - .putHeader(HEADER_MEMBER_NAME, HEADER_MEMBER) - .build(), - new ExecutionAttributes()); + .content(AbortableInputStream.create(SdkBytes.fromUtf8String(payload).asInputStream())) + .putHeader(HEADER_MEMBER_NAME, HEADER_MEMBER) + .build(), + new ExecutionAttributes()); assertThat(unmarshalled.headerMember()).isEqualTo(HEADER_MEMBER); - assertThat(unmarshalled.explicitPayloadMember().asUtf8String()).isEqualTo(EXPLICIT_PAYLOAD); + assertThat(unmarshalled.explicitPayloadMember().asUtf8String()).isEqualTo(payload); } - @Test - public void testMarshalling() { + @ParameterizedTest + @ValueSource(strings = {EXPLICIT_PAYLOAD_JSON, EXPLICIT_PAYLOAD_NON_JSON}) + public void testUnmarshalling_StringPayload(String payload) throws Exception { + HttpResponseHandler responseHandler = protocolFactory + .createResponseHandler(JsonOperationMetadata.builder().build(), InputEventStringPayload::builder); + + InputEventStringPayload unmarshalled = (InputEventStringPayload) responseHandler.handle(SdkHttpFullResponse.builder() + .content(AbortableInputStream.create(SdkBytes.fromUtf8String(payload).asInputStream())) + .putHeader(HEADER_MEMBER_NAME, HEADER_MEMBER) + .build(), + new ExecutionAttributes()); + + assertThat(unmarshalled.headerMember()).isEqualTo(HEADER_MEMBER); + assertThat(unmarshalled.explicitPayloadStringMember()).isEqualTo(payload); + } + + @ParameterizedTest + @ValueSource(strings = {EXPLICIT_PAYLOAD_JSON, EXPLICIT_PAYLOAD_NON_JSON}) + public void testMarshalling_BlobPayload(String payload) { InputEventMarshaller marshaller = new InputEventMarshaller(protocolFactory); InputEvent e = InputEvent.builder() .headerMember(HEADER_MEMBER) - .explicitPayloadMember(SdkBytes.fromUtf8String(EXPLICIT_PAYLOAD)) + .explicitPayloadMember(SdkBytes.fromUtf8String(payload)) .build(); SdkHttpFullRequest marshalled = marshaller.marshall(e); assertThat(marshalled.headers().get(HEADER_MEMBER_NAME)).containsExactly(HEADER_MEMBER); assertThat(marshalled.contentStreamProvider().get().newStream()) - .hasSameContentAs(SdkBytes.fromUtf8String(EXPLICIT_PAYLOAD).asInputStream()); + .hasSameContentAs(SdkBytes.fromUtf8String(payload).asInputStream()); + } + + @ParameterizedTest + @ValueSource(strings = {EXPLICIT_PAYLOAD_JSON, EXPLICIT_PAYLOAD_NON_JSON}) + public void testMarshalling_StringPayload(String payload) { + InputEventStringPayloadMarshaller marshaller = new InputEventStringPayloadMarshaller(protocolFactory); + + InputEventStringPayload e = InputEventStringPayload.builder() + .headerMember(HEADER_MEMBER) + .explicitPayloadStringMember(payload) + .build(); + + SdkHttpFullRequest marshalled = marshaller.marshall(e); + + assertThat(marshalled.headers().get(HEADER_MEMBER_NAME)).containsExactly(HEADER_MEMBER); + assertThat(marshalled.contentStreamProvider().get().newStream()) + .hasSameContentAs(SdkBytes.fromUtf8String(payload).asInputStream()); } } diff --git a/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/RestJsonEventStreamProtocolTest.java b/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/RestJsonEventStreamProtocolTest.java index 2dd0293683ae..53a19f98f5ae 100644 --- a/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/RestJsonEventStreamProtocolTest.java +++ b/test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/RestJsonEventStreamProtocolTest.java @@ -32,9 +32,11 @@ import software.amazon.awssdk.services.protocolrestjsoncontenttype.model.ImplicitPayloadAndHeadersEvent; import software.amazon.awssdk.services.protocolrestjsoncontenttype.model.InputEventStream; import software.amazon.awssdk.services.protocolrestjsoncontenttype.model.ProtocolRestJsonContentTypeException; +import software.amazon.awssdk.services.protocolrestjsoncontenttype.model.StringAndHeadersEvent; import software.amazon.awssdk.services.protocolrestjsoncontenttype.transform.BlobAndHeadersEventMarshaller; import software.amazon.awssdk.services.protocolrestjsoncontenttype.transform.HeadersOnlyEventMarshaller; import software.amazon.awssdk.services.protocolrestjsoncontenttype.transform.ImplicitPayloadAndHeadersEventMarshaller; +import software.amazon.awssdk.services.protocolrestjsoncontenttype.transform.StringAndHeadersEventMarshaller; public class RestJsonEventStreamProtocolTest { private static final String EVENT_CONTENT_TYPE_HEADER = ":content-type"; @@ -89,6 +91,22 @@ public void blobAndHeadersEvent() { assertThat(content).isEqualTo("hello rest-json"); } + @Test + public void stringAndHeadersEvent() { + StringAndHeadersEventMarshaller marshaller = new StringAndHeadersEventMarshaller(protocolFactory()); + + StringAndHeadersEvent event = InputEventStream.stringAndHeadersEventBuilder() + .headerMember("hello rest-json") + .stringPayloadMember("hello rest-json") + .build(); + + SdkHttpFullRequest marshalledEvent = marshaller.marshall(event); + + assertThat(marshalledEvent.headers().get(EVENT_CONTENT_TYPE_HEADER)).containsExactly("text/plain"); + String content = contentAsString(marshalledEvent); + assertThat(content).isEqualTo("hello rest-json"); + } + @Test public void headersOnly() { HeadersOnlyEventMarshaller marshaller = new HeadersOnlyEventMarshaller(protocolFactory());