Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/master/sra-identity-auth…
Browse files Browse the repository at this point in the history
…' into feature/master/sra-identity-auth-testing
  • Loading branch information
zoewangg committed Sep 29, 2023
2 parents b55a20f + 43b5b43 commit ec19254
Show file tree
Hide file tree
Showing 14 changed files with 471 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.Header;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.internal.signer.CredentialScope;
import software.amazon.awssdk.http.auth.aws.internal.signer.checksums.SdkChecksum;
Expand All @@ -52,6 +54,7 @@ public final class AwsChunkedV4aPayloadSigner implements V4aPayloadSigner {
private final CredentialScope credentialScope;
private final int chunkSize;
private final ChecksumAlgorithm checksumAlgorithm;
private final List<Pair<String, List<String>>> preExistingTrailers = new ArrayList<>();

private AwsChunkedV4aPayloadSigner(Builder builder) {
this.credentialScope = Validate.paramNotNull(builder.credentialScope, "CredentialScope");
Expand All @@ -65,16 +68,14 @@ public static Builder builder() {

@Override
public ContentStreamProvider sign(ContentStreamProvider payload, V4aContext v4aContext) {
SdkHttpRequest.Builder request = v4aContext.getSignedRequest();
moveContentLength(request);

InputStream inputStream = payload != null ? payload.newStream() : new StringInputStream("");
ChunkedEncodedInputStream.Builder chunkedEncodedInputStreamBuilder = ChunkedEncodedInputStream
.builder()
.inputStream(inputStream)
.chunkSize(chunkSize)
.header(chunk -> Integer.toHexString(chunk.length).getBytes(StandardCharsets.UTF_8));
setupPreExistingTrailers(chunkedEncodedInputStreamBuilder, request);

preExistingTrailers.forEach(trailer -> chunkedEncodedInputStreamBuilder.addTrailer(() -> trailer));

switch (v4aContext.getSigningConfig().getSignedBodyValue()) {
case STREAMING_ECDSA_SIGNED_PAYLOAD: {
Expand All @@ -83,12 +84,12 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4aContext v4aC
break;
}
case STREAMING_UNSIGNED_PAYLOAD_TRAILER:
setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder, request);
setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder);
break;
case STREAMING_ECDSA_SIGNED_PAYLOAD_TRAILER: {
RollingSigner rollingSigner = new RollingSigner(v4aContext.getSignature(), v4aContext.getSigningConfig());
chunkedEncodedInputStreamBuilder.addExtension(new SigV4aChunkExtensionProvider(rollingSigner, credentialScope));
setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder, request);
setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder);
chunkedEncodedInputStreamBuilder.addTrailer(
new SigV4aTrailerProvider(chunkedEncodedInputStreamBuilder.trailers(), rollingSigner, credentialScope)
);
Expand All @@ -101,49 +102,152 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4aContext v4aC
return new ResettableContentStreamProvider(chunkedEncodedInputStreamBuilder::build);
}

/**
* Add the checksum as a chunk-trailer and add it to the request's trailer header.
* <p>
* The checksum-algorithm MUST be set if this is called, otherwise it will throw.
*/
private void setupChecksumTrailerIfNeeded(ChunkedEncodedInputStream.Builder builder, SdkHttpRequest.Builder request) {
if (checksumAlgorithm == null) {
return;
@Override
public void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider payload, String checksum) {
long encodedContentLength = 0;
long contentLength = moveContentLength(request, payload != null ? payload.newStream() : new StringInputStream(""));
setupPreExistingTrailers(request);

// pre-existing trailers
encodedContentLength += calculateExistingTrailersLength();

switch (checksum) {
case STREAMING_ECDSA_SIGNED_PAYLOAD: {
long extensionsLength = 161; // ;chunk-signature:<sigv4a-ecsda hex signature, 144 bytes>
encodedContentLength += calculateChunksLength(contentLength, extensionsLength);
break;
}
case STREAMING_UNSIGNED_PAYLOAD_TRAILER:
if (checksumAlgorithm != null) {
encodedContentLength += calculateChecksumTrailerLength(checksumHeaderName(checksumAlgorithm));
}
encodedContentLength += calculateChunksLength(contentLength, 0);
break;
case STREAMING_ECDSA_SIGNED_PAYLOAD_TRAILER: {
long extensionsLength = 161; // ;chunk-signature:<sigv4a-ecsda hex signature, 144 bytes>
encodedContentLength += calculateChunksLength(contentLength, extensionsLength);
if (checksumAlgorithm != null) {
encodedContentLength += calculateChecksumTrailerLength(checksumHeaderName(checksumAlgorithm));
}
encodedContentLength += 170; // x-amz-trailer-signature:<sigv4a-ecsda hex signature, 144 bytes>\r\n
break;
}
default:
throw new UnsupportedOperationException();
}
SdkChecksum sdkChecksum = fromChecksumAlgorithm(checksumAlgorithm);
ChecksumInputStream checksumInputStream = new ChecksumInputStream(
builder.inputStream(),
Collections.singleton(sdkChecksum)
);
String checksumHeaderName = checksumHeaderName(checksumAlgorithm);

TrailerProvider checksumTrailer = new ChecksumTrailerProvider(sdkChecksum, checksumHeaderName);
// terminating \r\n
encodedContentLength += 2;

request.appendHeader(X_AMZ_TRAILER, checksumHeaderName);
builder.inputStream(checksumInputStream).addTrailer(checksumTrailer);
if (checksumAlgorithm != null) {
String checksumHeaderName = checksumHeaderName(checksumAlgorithm);
request.appendHeader(X_AMZ_TRAILER, checksumHeaderName);
}
request.putHeader(Header.CONTENT_LENGTH, Long.toString(encodedContentLength));
}

/**
* Create chunk-trailers for each pre-existing trailer given in the request.
* Set up a map of pre-existing trailer (headers) for the given request to be used when chunk-encoding the payload.
* <p>
* However, we need to validate that these are valid trailers. Since aws-chunked encoding adds the checksum as a trailer, it
* isn't part of the request headers, but other trailers MUST be present in the request-headers.
*/
private void setupPreExistingTrailers(ChunkedEncodedInputStream.Builder builder, SdkHttpRequest.Builder request) {
List<String> trailerHeaders = request.matchingHeaders(X_AMZ_TRAILER);

for (String header : trailerHeaders) {
private void setupPreExistingTrailers(SdkHttpRequest.Builder request) {
for (String header : request.matchingHeaders(X_AMZ_TRAILER)) {
List<String> values = request.matchingHeaders(header);
if (values.isEmpty()) {
throw new IllegalArgumentException(header + " must be present in the request headers to be a valid trailer.");
}

// Add the trailer to the aws-chunked stream-builder, and remove it from the request headers
builder.addTrailer(() -> Pair.of(header, values));
preExistingTrailers.add(Pair.of(header, values));
request.removeHeader(header);
}
}

private long calculateChunksLength(long contentLength, long extensionsLength) {
long lengthInBytes = 0;
long chunkHeaderLength = Integer.toHexString(chunkSize).length();
long numChunks = contentLength / chunkSize;

// normal chunks
// x<metadata>\r\n<data>\r\n
lengthInBytes += numChunks * (chunkHeaderLength + extensionsLength + 2 + chunkSize + 2);

// remaining chunk
// x<metadata>\r\n<data>\r\n
long remainingBytes = contentLength % chunkSize;
if (remainingBytes > 0) {
long remainingChunkHeaderLength = Long.toHexString(remainingBytes).length();
lengthInBytes += remainingChunkHeaderLength + extensionsLength + 2 + remainingBytes + 2;
}

// final chunk
// 0<metadata>\r\n
lengthInBytes += 1 + extensionsLength + 2;

return lengthInBytes;
}

private long calculateExistingTrailersLength() {
long lengthInBytes = 0;

for (Pair<String, List<String>> trailer : preExistingTrailers) {
// size of trailer
lengthInBytes += calculateTrailerLength(trailer);
}

return lengthInBytes;
}

private long calculateTrailerLength(Pair<String, List<String>> trailer) {
// size of trailer-header and colon
long lengthInBytes = trailer.left().length() + 1;

// size of trailer-values
for (String value : trailer.right()) {
lengthInBytes += value.length();
}

// size of commas between trailer-values, 1 less comma than # of values
lengthInBytes += trailer.right().size() - 1;

// terminating \r\n
return lengthInBytes + 2;
}

private long calculateChecksumTrailerLength(String checksumHeaderName) {
// size of checksum trailer-header and colon
long lengthInBytes = checksumHeaderName.length() + 1;

// get the base checksum for the algorithm
SdkChecksum sdkChecksum = fromChecksumAlgorithm(checksumAlgorithm);
// size of checksum value as hex-string
lengthInBytes += sdkChecksum.getChecksum().length();

// terminating \r\n
return lengthInBytes + 2;
}

/**
* Add the checksum as a trailer to the chunk-encoded stream.
* <p>
* If the checksum-algorithm is not present, then nothing is done.
*/
private void setupChecksumTrailerIfNeeded(ChunkedEncodedInputStream.Builder builder) {
if (checksumAlgorithm == null) {
return;
}
String checksumHeaderName = checksumHeaderName(checksumAlgorithm);
SdkChecksum sdkChecksum = fromChecksumAlgorithm(checksumAlgorithm);
ChecksumInputStream checksumInputStream = new ChecksumInputStream(
builder.inputStream(),
Collections.singleton(sdkChecksum)
);

TrailerProvider checksumTrailer = new ChecksumTrailerProvider(sdkChecksum, checksumHeaderName);

builder.inputStream(checksumInputStream).addTrailer(checksumTrailer);
}

static final class Builder {
private CredentialScope credentialScope;
private Integer chunkSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ private static SignedRequest doSign(SignRequest<? extends AwsCredentialsIdentity
.build();
}

SdkHttpRequest sanitizedRequest = sanitizeRequest(request.request());
SdkHttpRequest.Builder requestBuilder = request.request().toBuilder();

payloadSigner.beforeSigning(requestBuilder, request.payload().orElse(null), signingConfig.getSignedBodyValue());

SdkHttpRequest sanitizedRequest = sanitizeRequest(requestBuilder.build());

HttpRequest crtRequest = toRequest(sanitizedRequest, request.payload().orElse(null));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ public RollingSigner(byte[] seedSignature, AwsSigningConfig signingConfig) {
}

private static byte[] signChunk(byte[] chunkBody, byte[] previousSignature, AwsSigningConfig signingConfig) {
// All the config remains the same as signing config except the Signature Type.
AwsSigningConfig configCopy = signingConfig.clone();
configCopy.setSignatureType(AwsSigningConfig.AwsSignatureType.HTTP_REQUEST_CHUNK);

HttpRequestBodyStream crtBody = new CrtInputStream(() -> new ByteArrayInputStream(chunkBody));
return CompletableFutureUtils.joinLikeSync(AwsSigner.signChunk(crtBody, previousSignature, signingConfig));
return CompletableFutureUtils.joinLikeSync(AwsSigner.signChunk(crtBody, previousSignature, configCopy));
}

private static AwsSigningResult signTrailerHeaders(Map<String, List<String>> headerMap, byte[] previousSignature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpRequest;

/**
* An interface for defining how to sign a payload via SigV4a.
Expand All @@ -34,4 +35,10 @@ static V4aPayloadSigner create() {
* Given a payload and v4a-context, sign the payload via the SigV4a process.
*/
ContentStreamProvider sign(ContentStreamProvider payload, V4aContext v4Context);

/**
* Modify a request before it is signed, such as changing headers or query-parameters.
*/
default void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider payload, String checksum) {
}
}
Loading

0 comments on commit ec19254

Please sign in to comment.