Skip to content

Commit

Permalink
Improve trait lookup using TraitType
Browse files Browse the repository at this point in the history
  • Loading branch information
sugmanue committed Sep 5, 2024
1 parent e3a8727 commit 0d997df
Show file tree
Hide file tree
Showing 36 changed files with 315 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import software.amazon.awssdk.core.traits.JsonValueTrait;
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.core.traits.RequiredTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.protocols.core.ValueToStringConverter;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.StringUtils;
Expand All @@ -35,7 +36,7 @@
public final class HeaderMarshaller {

public static final JsonMarshaller<String> STRING = new SimpleHeaderMarshaller<>(
(val, field) -> field.containsTrait(JsonValueTrait.class) ?
(val, field) -> field.containsTrait(JsonValueTrait.class, TraitType.JSON_VALUE_TRAIT) ?
BinaryUtils.toBase64(val.getBytes(StandardCharsets.UTF_8)) : val);

public static final JsonMarshaller<Integer> INTEGER = new SimpleHeaderMarshaller<>(ValueToStringConverter.FROM_INTEGER);
Expand All @@ -61,7 +62,7 @@ public final class HeaderMarshaller {
if (isNullOrEmpty(list)) {
return;
}
SdkField memberFieldInfo = sdkField.getRequiredTrait(ListTrait.class).memberFieldInfo();
SdkField memberFieldInfo = sdkField.getRequiredTrait(ListTrait.class, TraitType.LIST_TRAIT).memberFieldInfo();
for (Object listValue : list) {
if (shouldSkipElement(listValue)) {
continue;
Expand All @@ -72,7 +73,7 @@ public final class HeaderMarshaller {
};

public static final JsonMarshaller<Void> NULL = (val, context, paramName, sdkField) -> {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class)) {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) {
throw new IllegalArgumentException(String.format("Parameter '%s' must not be null", paramName));
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.PayloadTrait;
import software.amazon.awssdk.core.traits.TimestampFormatTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.protocols.core.InstantToString;
import software.amazon.awssdk.protocols.core.OperationInfo;
Expand Down Expand Up @@ -227,7 +228,7 @@ private boolean isExplicitStringPayload(SdkField<?> field) {
}

private boolean isExplicitPayloadMember(SdkField<?> field) {
return field.containsTrait(PayloadTrait.class);
return field.containsTrait(PayloadTrait.class, TraitType.PAYLOAD_TRAIT);
}

private void marshallExplicitJsonPayload(SdkField<?> field, Object val) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.traits.RequiredTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.protocols.core.ValueToStringConverter;

@SdkInternalApi
Expand Down Expand Up @@ -65,7 +66,7 @@ public final class QueryParamMarshaller {
};

public static final JsonMarshaller<Void> NULL = (val, context, paramName, sdkField) -> {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class)) {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) {
throw new IllegalArgumentException(String.format("Parameter '%s' must not be null", paramName));
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.traits.RequiredTrait;
import software.amazon.awssdk.core.traits.TimestampFormatTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.core.util.SdkAutoConstructList;
import software.amazon.awssdk.core.util.SdkAutoConstructMap;
import software.amazon.awssdk.protocols.json.StructuredJsonGenerator;
Expand All @@ -39,7 +40,7 @@
public final class SimpleTypeJsonMarshaller {

public static final JsonMarshaller<Void> NULL = (val, context, paramName, sdkField) -> {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class)) {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) {
throw new IllegalArgumentException(String.format("Parameter '%s' must not be null",
Optional.ofNullable(paramName)
.orElseGet(() -> "paramName null")));
Expand Down Expand Up @@ -122,7 +123,8 @@ public void marshall(Boolean val, StructuredJsonGenerator jsonGenerator, JsonMar
if (paramName != null) {
jsonGenerator.writeFieldName(paramName);
}
TimestampFormatTrait trait = sdkField != null ? sdkField.getTrait(TimestampFormatTrait.class) : null;
TimestampFormatTrait trait = sdkField != null ? sdkField.getTrait(TimestampFormatTrait.class,
TraitType.TIMESTAMP_FORMAT_TRAIT) : null;
if (trait != null) {
switch (trait.format()) {
case UNIX_TIMESTAMP:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.traits.JsonValueTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.protocols.core.StringToValueConverter;
import software.amazon.awssdk.protocols.jsoncore.JsonNode;
import software.amazon.awssdk.utils.BinaryUtils;
Expand Down Expand Up @@ -59,7 +60,7 @@ private HeaderUnmarshaller() {
*/
private static String unmarshallStringHeader(String value,
SdkField<String> field) {
return field.containsTrait(JsonValueTrait.class) ?
return field.containsTrait(JsonValueTrait.class, TraitType.JSON_VALUE_TRAIT) ?
new String(BinaryUtils.fromBase64(value), StandardCharsets.UTF_8) : value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.core.traits.MapTrait;
import software.amazon.awssdk.core.traits.PayloadTrait;
Expand Down Expand Up @@ -179,7 +180,7 @@ private static Document getDocumentFromJsonContent(JsonNode jsonContent) {
return null;
}

SdkField<Object> valueInfo = field.getTrait(MapTrait.class).valueFieldInfo();
SdkField<Object> valueInfo = field.getTrait(MapTrait.class, TraitType.MAP_TRAIT).valueFieldInfo();
JsonUnmarshaller<Object> unmarshaller = context.getUnmarshaller(valueInfo.location(), valueInfo.marshallingType());
Map<String, JsonNode> asObject = jsonContent.asObject();
Map<String, Object> map = new HashMap<>(asObject.size());
Expand All @@ -194,7 +195,7 @@ private static List<?> unmarshallList(JsonUnmarshallerContext context, JsonNode
return null;
}

SdkField<Object> memberInfo = field.getTrait(ListTrait.class).memberFieldInfo();
SdkField<Object> memberInfo = field.getTrait(ListTrait.class, TraitType.LIST_TRAIT).memberFieldInfo();
List<JsonNode> asArray = jsonContent.asArray();
List<Object> result = new ArrayList<>(asArray.size());
for (JsonNode node : asArray) {
Expand Down Expand Up @@ -249,7 +250,7 @@ private boolean isExplicitStringPayloadMember(SdkField<?> f) {
}

private static boolean isExplicitPayloadMember(SdkField<?> f) {
return f.containsTrait(PayloadTrait.class);
return f.containsTrait(PayloadTrait.class, TraitType.PAYLOAD_TRAIT);
}

private boolean isPayloadMemberOnUnmarshall(SdkField<?> f) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.core.util.SdkAutoConstructList;

/**
Expand Down Expand Up @@ -52,7 +53,7 @@ public void marshall(QueryMarshallerContext context, String path, List<?> val, S
return;
}
for (int i = 0; i < val.size(); i++) {
ListTrait listTrait = sdkField.getTrait(ListTrait.class);
ListTrait listTrait = sdkField.getTrait(ListTrait.class, TraitType.LIST_TRAIT);
String listPath = pathResolver.resolve(path, i, listTrait);
QueryMarshaller<Object> marshaller = context.marshallerRegistry().getMarshaller(
((SdkField<?>) listTrait.memberFieldInfo()).marshallingType(), val);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.traits.MapTrait;
import software.amazon.awssdk.core.traits.TraitType;

@SdkInternalApi
public class MapQueryMarshaller implements QueryMarshaller<Map<String, ?>> {

@Override
public void marshall(QueryMarshallerContext context, String path, Map<String, ?> val, SdkField<Map<String, ?>> sdkField) {
MapTrait mapTrait = sdkField.getTrait(MapTrait.class);
MapTrait mapTrait = sdkField.getTrait(MapTrait.class, TraitType.MAP_TRAIT);
AtomicInteger entryNum = new AtomicInteger(1);
val.forEach((key, value) -> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.protocols.query.unmarshall.XmlElement;

@SdkInternalApi
public final class ListQueryUnmarshaller implements QueryUnmarshaller<List<?>> {

@Override
public List<?> unmarshall(QueryUnmarshallerContext context, List<XmlElement> content, SdkField<List<?>> field) {
ListTrait listTrait = field.getTrait(ListTrait.class);
ListTrait listTrait = field.getTrait(ListTrait.class, TraitType.LIST_TRAIT);
List<Object> list = new ArrayList<>();
getMembers(content, listTrait).forEach(member -> {
QueryUnmarshaller unmarshaller = context.getUnmarshaller(listTrait.memberFieldInfo().location(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.traits.MapTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.protocols.query.unmarshall.XmlElement;

@SdkInternalApi
Expand All @@ -31,7 +32,7 @@ public final class MapQueryUnmarshaller implements QueryUnmarshaller<Map<String,
@Override
public Map<String, ?> unmarshall(QueryUnmarshallerContext context, List<XmlElement> content, SdkField<Map<String, ?>> field) {
Map<String, Object> map = new HashMap<>();
MapTrait mapTrait = field.getTrait(MapTrait.class);
MapTrait mapTrait = field.getTrait(MapTrait.class, TraitType.MAP_TRAIT);
SdkField mapValueSdkField = mapTrait.valueFieldInfo();

getEntries(content, mapTrait).forEach(entry -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.PayloadTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.http.SdkHttpFullResponse;
import software.amazon.awssdk.protocols.core.StringToInstant;
import software.amazon.awssdk.protocols.core.StringToValueConverter;
Expand Down Expand Up @@ -90,7 +91,7 @@ public <TypeT extends SdkPojo> Pair<TypeT, Map<String, String>> unmarshall(SdkPo
private boolean responsePayloadIsBlob(SdkPojo sdkPojo) {
return sdkPojo.sdkFields().stream()
.anyMatch(field -> field.marshallingType() == MarshallingType.SDK_BYTES &&
field.containsTrait(PayloadTrait.class));
field.containsTrait(PayloadTrait.class, TraitType.PAYLOAD_TRAIT));
}

/**
Expand Down Expand Up @@ -128,7 +129,9 @@ private String metadataKeyName(XmlElement c) {
private SdkPojo unmarshall(QueryUnmarshallerContext context, SdkPojo sdkPojo, XmlElement root) {
if (root != null) {
for (SdkField<?> field : sdkPojo.sdkFields()) {
if (field.containsTrait(PayloadTrait.class) && field.marshallingType() == MarshallingType.SDK_BYTES) {
if (field.containsTrait(PayloadTrait.class, TraitType.PAYLOAD_TRAIT)
&& field.marshallingType() == MarshallingType.SDK_BYTES
) {
field.set(sdkPojo, SdkBytes.fromUtf8String(root.textContent()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.core.traits.RequiredTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.protocols.core.ValueToStringConverter;
import software.amazon.awssdk.utils.StringUtils;

Expand Down Expand Up @@ -79,7 +80,7 @@ public void marshall(List<?> list, XmlMarshallerContext context, String paramNam
if (!shouldEmit(list)) {
return;
}
SdkField memberFieldInfo = sdkField.getRequiredTrait(ListTrait.class).memberFieldInfo();
SdkField memberFieldInfo = sdkField.getRequiredTrait(ListTrait.class, TraitType.LIST_TRAIT).memberFieldInfo();
for (Object listValue : list) {
if (shouldSkipElement(listValue)) {
continue;
Expand All @@ -102,7 +103,7 @@ protected boolean shouldEmit(List list) {
};

public static final XmlMarshaller<Void> NULL = (val, context, paramName, sdkField) -> {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class)) {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) {
throw new IllegalArgumentException(String.format("Parameter '%s' must not be null", paramName));
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.core.traits.MapTrait;
import software.amazon.awssdk.core.traits.RequiredTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.protocols.core.ValueToStringConverter;

@SdkInternalApi
Expand Down Expand Up @@ -62,11 +63,11 @@ public final class QueryParamMarshaller {
return;
}

MapTrait mapTrait = sdkField.getRequiredTrait(MapTrait.class);
MapTrait mapTrait = sdkField.getRequiredTrait(MapTrait.class, TraitType.MAP_TRAIT);
SdkField valueField = mapTrait.valueFieldInfo();

for (Map.Entry<String, ?> entry : map.entrySet()) {
if (valueField.containsTrait(ListTrait.class)) {
if (valueField.containsTrait(ListTrait.class, TraitType.LIST_TRAIT)) {
((List<?>) entry.getValue()).forEach(val -> {
context.marshallerRegistry().getMarshaller(MarshallLocation.QUERY_PARAM, val)
.marshall(val, context, entry.getKey(), null);
Expand All @@ -81,7 +82,7 @@ public final class QueryParamMarshaller {
};

public static final XmlMarshaller<Void> NULL = (val, context, paramName, sdkField) -> {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class)) {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) {
throw new IllegalArgumentException(String.format("Parameter '%s' must not be null", paramName));
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import software.amazon.awssdk.core.traits.ListTrait;
import software.amazon.awssdk.core.traits.MapTrait;
import software.amazon.awssdk.core.traits.RequiredTrait;
import software.amazon.awssdk.core.traits.TraitType;
import software.amazon.awssdk.core.traits.XmlAttributeTrait;
import software.amazon.awssdk.core.traits.XmlAttributesTrait;
import software.amazon.awssdk.core.util.SdkAutoConstructList;
Expand Down Expand Up @@ -83,7 +84,7 @@ public void marshall(List<?> val, XmlMarshallerContext context, String paramName
@Override
public void marshall(List<?> list, XmlMarshallerContext context, String paramName,
SdkField<List<?>> sdkField, ValueToStringConverter.ValueToString<List<?>> converter) {
ListTrait listTrait = sdkField.getRequiredTrait(ListTrait.class);
ListTrait listTrait = sdkField.getRequiredTrait(ListTrait.class, TraitType.LIST_TRAIT);

if (!listTrait.isFlattened()) {
context.xmlGenerator().startElement(paramName);
Expand Down Expand Up @@ -125,7 +126,7 @@ protected boolean shouldEmit(List list, String paramName) {
public void marshall(Map<String, ?> map, XmlMarshallerContext context, String paramName,
SdkField<Map<String, ?>> sdkField, ValueToStringConverter.ValueToString<Map<String, ?>> converter) {

MapTrait mapTrait = sdkField.getRequiredTrait(MapTrait.class);
MapTrait mapTrait = sdkField.getRequiredTrait(MapTrait.class, TraitType.MAP_TRAIT);

for (Map.Entry<String, ?> entry : map.entrySet()) {
context.xmlGenerator().startElement("entry");
Expand All @@ -144,7 +145,7 @@ protected boolean shouldEmit(Map map, String paramName) {
};

public static final XmlMarshaller<Void> NULL = (val, context, paramName, sdkField) -> {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class)) {
if (Objects.nonNull(sdkField) && sdkField.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) {
throw new IllegalArgumentException(String.format("Parameter '%s' must not be null", paramName));
}
};
Expand Down Expand Up @@ -180,8 +181,11 @@ public void marshall(T val, XmlMarshallerContext context, String paramName, SdkF
return;
}

if (sdkField != null && sdkField.getOptionalTrait(XmlAttributesTrait.class).isPresent()) {
XmlAttributesTrait attributeTrait = sdkField.getTrait(XmlAttributesTrait.class);
boolean hasXmlAttributesTrait = sdkField != null &&
sdkField.getOptionalTrait(XmlAttributesTrait.class,
TraitType.XML_ATTRIBUTES_TRAIT).isPresent();
if (hasXmlAttributesTrait) {
XmlAttributesTrait attributeTrait = sdkField.getTrait(XmlAttributesTrait.class, TraitType.XML_ATTRIBUTES_TRAIT);
Map<String, String> attributes = attributeTrait.attributes()
.entrySet()
.stream()
Expand Down Expand Up @@ -209,7 +213,8 @@ protected boolean shouldEmit(T val, String paramName) {
}

private boolean isXmlAttribute(SdkField<T> sdkField) {
return sdkField != null && sdkField.getOptionalTrait(XmlAttributeTrait.class).isPresent();
return sdkField != null && sdkField.getOptionalTrait(XmlAttributeTrait.class,
TraitType.XML_ATTRIBUTE_TRAIT).isPresent();
}
}

Expand Down
Loading

0 comments on commit 0d997df

Please sign in to comment.