Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SigV4: Add host header only when not already provided #5608

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AWSSDKforJavav2-b3fbc61.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "AWS SDK for Java v2",
"contributor": "vsudilov",
"description": "SigV4: Add host header only when not already provided"
}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ public static void addHostHeader(SdkHttpRequest.Builder requestBuilder) {
// AWS4 requires that we sign the Host header, so we
// have to have it in the request by the time we sign.

// If the SdkHttpRequest has an associated Host header
// already set, prefer to use that.

if (requestBuilder.headers().get(SignerConstant.HOST) != null) {
return;
}

String host = requestBuilder.host();
if (!SdkHttpUtils.isUsingStandardPort(requestBuilder.protocol(), requestBuilder.port())) {
StringBuilder hostHeaderBuilder = new StringBuilder(host);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant;
import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
import software.amazon.awssdk.identity.spi.AwsSessionCredentialsIdentity;

Expand Down Expand Up @@ -58,6 +59,7 @@ public void sign_computesSigningResult() {
assertEquals(expectedCanonicalRequestString, result.getCanonicalRequest().getCanonicalRequestString());
}


@Test
public void sign_withHeader_addsAuthHeaders() {
String expectedAuthorization = "AWS4-HMAC-SHA256 Credential=access/19700101/us-east-1/demo/aws4_request, " +
Expand All @@ -82,6 +84,21 @@ public void sign_withHeaderAndSessionCredentials_addsAuthHeadersAndTokenHeader()
assertThat(result.getSignedRequest().firstMatchingHeader("X-Amz-Security-Token")).hasValue("token");
}

@Test
public void sign_withHeaderAndSessionCredentials_correctSigningUsingProvidedHostHeader() {
String expectedAuthorization = "AWS4-HMAC-SHA256 Credential=access/19700101/us-east-1/demo/aws4_request, " +
"SignedHeaders=host;x-amz-archive-description;x-amz-content-sha256;x-amz-date;"
+ "x-amz-security-token, " +
"Signature=c8228e7bef8a72a450df38e6e935ce61fdb8989670b41d97cfc20d04bb76b10a";
SdkHttpRequest.Builder request = getRequest().putHeader(SignerConstant.HOST, "virtual-host.localhost");
V4RequestSigningResult result = header(getProperties(sessionCreds)).sign(request);

assertThat(result.getSignedRequest().firstMatchingHeader("Host")).hasValue("virtual-host.localhost");
assertThat(result.getSignedRequest().firstMatchingHeader("X-Amz-Date")).hasValue("19700101T000000Z");
assertThat(result.getSignedRequest().firstMatchingHeader("Authorization")).hasValue(expectedAuthorization);
assertThat(result.getSignedRequest().firstMatchingHeader("X-Amz-Security-Token")).hasValue("token");
}

@Test
public void sign_withQuery_addsAuthQueryParams() {
V4RequestSigningResult result = query(getProperties(creds)).sign(getRequest());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.client.config.RequestConfig;
Expand Down Expand Up @@ -55,7 +56,6 @@ public HttpRequestBase create(final HttpExecuteRequest request, final ApacheHttp
HttpRequestBase base = createApacheRequest(request, sanitizeUri(request.httpRequest()));
addHeadersToRequest(base, request.httpRequest());
addRequestConfig(base, request.httpRequest(), requestConfig);

return base;
}

Expand Down Expand Up @@ -172,7 +172,7 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req
// it's already present, so we skip it here. We also skip the Host
// header to avoid sending it twice, which will interfere with some
// signing schemes.
if (!IGNORE_HEADERS.contains(name)) {
if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) {
for (String headerValue : value) {
httpRequest.addHeader(name, headerValue);
}
Expand All @@ -181,6 +181,11 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req
}

private String getHostHeaderValue(SdkHttpRequest request) {
// Respect any user-specified Host header when present
Optional<String> existingHostHeader = request.firstMatchingHeader(HttpHeaders.HOST);
if (existingHostHeader.isPresent()) {
return existingHostHeader.get();
}
// Apache doesn't allow us to include the port in the host header if it's a standard port for that protocol. For that
// reason, we don't include the port when we sign the message. See {@link SdkHttpRequest#port()}.
return !SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,46 @@ public void createSetsHostHeaderByDefault() {
assertEquals("localhost:12345", hostHeaders[0].getValue());
}

@Test
public void createRespectsUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("Host", hostOverride)
.build();
HttpExecuteRequest request = HttpExecuteRequest.builder()
.request(sdkRequest)
.build();

HttpRequestBase result = instance.create(request, requestConfig);

Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST);
assertNotNull(hostHeaders);
assertEquals(1, hostHeaders.length);
assertEquals(hostOverride, hostHeaders[0].getValue());
}

@Test
public void createRespectsLowercaseUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("host", hostOverride)
.build();
HttpExecuteRequest request = HttpExecuteRequest.builder()
.request(sdkRequest)
.build();

HttpRequestBase result = instance.create(request, requestConfig);

Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST);
assertNotNull(hostHeaders);
assertEquals(1, hostHeaders.length);
assertEquals(hostOverride, hostHeaders[0].getValue());
}

@Test
public void putRequest_withTransferEncodingChunked_isChunkedAndDoesNotIncludeHeader() {
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.handler.codec.http2.HttpConversionUtil.ExtensionHeaderNames;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.Protocol;
import software.amazon.awssdk.http.SdkHttpMethod;
Expand Down Expand Up @@ -87,13 +88,19 @@ private void addHeadersToRequest(DefaultHttpRequest httpRequest, SdkHttpRequest
// Copy over any other headers already in our request
request.forEachHeader((name, value) -> {
// Skip the Host header to avoid sending it twice, which will interfere with some signing schemes.
if (!IGNORE_HEADERS.contains(name)) {
if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) {
value.forEach(h -> httpRequest.headers().add(name, h));
}
});
}

private String getHostHeaderValue(SdkHttpRequest request) {
// Respect any user-specified Host header when present
Optional<String> existingHostHeader = request.firstMatchingHeader(HOST);
if (existingHostHeader.isPresent()) {
return existingHostHeader.get();
}

return SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port())
? request.host()
: request.host() + ":" + request.port();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,34 @@ public void adapt_hostHeaderSet() {
assertThat(hostHeaders).containsExactly("localhost:12345");
}

@Test
public void adapt_keepsUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("Host", hostOverride)
.build();
HttpRequest result = h1Adapter.adapt(sdkRequest);
List<String> hostHeaders = result.headers()
.getAll(HttpHeaderNames.HOST.toString());
assertThat(hostHeaders).containsExactly(hostOverride);
}

@Test
public void adapt_keepsLowercaseUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("host", hostOverride)
.build();
HttpRequest result = h1Adapter.adapt(sdkRequest);
List<String> hostHeaders = result.headers()
.getAll(HttpHeaderNames.HOST.toString());
assertThat(hostHeaders).containsExactly(hostOverride);
}

@Test
public void adapt_standardHttpsPort_omittedInHeader() {
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
Expand Down