Skip to content

Commit

Permalink
Merge pull request #1 from culldanx/culldanx-pr-2
Browse files Browse the repository at this point in the history
http-client: respect user host header
  • Loading branch information
vsudilov committed Sep 26, 2024
2 parents dfb0407 + bfa9cc1 commit 10b7b81
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.client.config.RequestConfig;
Expand Down Expand Up @@ -55,7 +56,6 @@ public HttpRequestBase create(final HttpExecuteRequest request, final ApacheHttp
HttpRequestBase base = createApacheRequest(request, sanitizeUri(request.httpRequest()));
addHeadersToRequest(base, request.httpRequest());
addRequestConfig(base, request.httpRequest(), requestConfig);

return base;
}

Expand Down Expand Up @@ -172,7 +172,7 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req
// it's already present, so we skip it here. We also skip the Host
// header to avoid sending it twice, which will interfere with some
// signing schemes.
if (!IGNORE_HEADERS.contains(name)) {
if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) {
for (String headerValue : value) {
httpRequest.addHeader(name, headerValue);
}
Expand All @@ -181,6 +181,11 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req
}

private String getHostHeaderValue(SdkHttpRequest request) {
// Respect any user-specified Host header when present
Optional<String> existingHostHeader = request.firstMatchingHeader(HttpHeaders.HOST);
if (existingHostHeader.isPresent()) {
return existingHostHeader.get();
}
// Apache doesn't allow us to include the port in the host header if it's a standard port for that protocol. For that
// reason, we don't include the port when we sign the message. See {@link SdkHttpRequest#port()}.
return !SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,46 @@ public void createSetsHostHeaderByDefault() {
assertEquals("localhost:12345", hostHeaders[0].getValue());
}

@Test
public void createRespectsUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("Host", hostOverride)
.build();
HttpExecuteRequest request = HttpExecuteRequest.builder()
.request(sdkRequest)
.build();

HttpRequestBase result = instance.create(request, requestConfig);

Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST);
assertNotNull(hostHeaders);
assertEquals(1, hostHeaders.length);
assertEquals(hostOverride, hostHeaders[0].getValue());
}

@Test
public void createRespectsLowercaseUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("host", hostOverride)
.build();
HttpExecuteRequest request = HttpExecuteRequest.builder()
.request(sdkRequest)
.build();

HttpRequestBase result = instance.create(request, requestConfig);

Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST);
assertNotNull(hostHeaders);
assertEquals(1, hostHeaders.length);
assertEquals(hostOverride, hostHeaders[0].getValue());
}

@Test
public void putRequest_withTransferEncodingChunked_isChunkedAndDoesNotIncludeHeader() {
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.handler.codec.http2.HttpConversionUtil.ExtensionHeaderNames;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.http.Protocol;
import software.amazon.awssdk.http.SdkHttpMethod;
Expand Down Expand Up @@ -87,13 +88,19 @@ private void addHeadersToRequest(DefaultHttpRequest httpRequest, SdkHttpRequest
// Copy over any other headers already in our request
request.forEachHeader((name, value) -> {
// Skip the Host header to avoid sending it twice, which will interfere with some signing schemes.
if (!IGNORE_HEADERS.contains(name)) {
if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) {
value.forEach(h -> httpRequest.headers().add(name, h));
}
});
}

private String getHostHeaderValue(SdkHttpRequest request) {
// Respect any user-specified Host header when present
Optional<String> existingHostHeader = request.firstMatchingHeader(HOST);
if (existingHostHeader.isPresent()) {
return existingHostHeader.get();
}

return SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port())
? request.host()
: request.host() + ":" + request.port();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,34 @@ public void adapt_hostHeaderSet() {
assertThat(hostHeaders).containsExactly("localhost:12345");
}

@Test
public void adapt_keepsUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("Host", hostOverride)
.build();
HttpRequest result = h1Adapter.adapt(sdkRequest);
List<String> hostHeaders = result.headers()
.getAll(HttpHeaderNames.HOST.toString());
assertThat(hostHeaders).containsExactly(hostOverride);
}

@Test
public void adapt_keepsLowercaseUserHostHeader() {
String hostOverride = "virtual.host:123";
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
.uri(URI.create("http://localhost:12345/"))
.method(SdkHttpMethod.HEAD)
.putHeader("host", hostOverride)
.build();
HttpRequest result = h1Adapter.adapt(sdkRequest);
List<String> hostHeaders = result.headers()
.getAll(HttpHeaderNames.HOST.toString());
assertThat(hostHeaders).containsExactly(hostOverride);
}

@Test
public void adapt_standardHttpsPort_omittedInHeader() {
SdkHttpRequest sdkRequest = SdkHttpRequest.builder()
Expand Down

0 comments on commit 10b7b81

Please sign in to comment.