From bfa9cc1a42d16ffb4e15cf3762bbb33b98d48cff Mon Sep 17 00:00:00 2001 From: Daniel Cullen Date: Wed, 25 Sep 2024 22:33:33 -0400 Subject: [PATCH] http-client: respect user host header --- .../impl/ApacheHttpRequestFactory.java | 9 ++++- .../impl/ApacheHttpRequestFactoryTest.java | 40 +++++++++++++++++++ .../nio/netty/internal/RequestAdapter.java | 9 ++++- .../netty/internal/RequestAdapterTest.java | 28 +++++++++++++ 4 files changed, 83 insertions(+), 3 deletions(-) diff --git a/http-clients/apache-client/src/main/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactory.java b/http-clients/apache-client/src/main/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactory.java index 7f0c484ba05c..cfb22343ba3f 100644 --- a/http-clients/apache-client/src/main/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactory.java +++ b/http-clients/apache-client/src/main/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactory.java @@ -20,6 +20,7 @@ import java.net.URI; import java.util.Arrays; import java.util.List; +import java.util.Optional; import org.apache.http.HttpEntity; import org.apache.http.HttpHeaders; import org.apache.http.client.config.RequestConfig; @@ -55,7 +56,6 @@ public HttpRequestBase create(final HttpExecuteRequest request, final ApacheHttp HttpRequestBase base = createApacheRequest(request, sanitizeUri(request.httpRequest())); addHeadersToRequest(base, request.httpRequest()); addRequestConfig(base, request.httpRequest(), requestConfig); - return base; } @@ -172,7 +172,7 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req // it's already present, so we skip it here. We also skip the Host // header to avoid sending it twice, which will interfere with some // signing schemes. - if (!IGNORE_HEADERS.contains(name)) { + if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) { for (String headerValue : value) { httpRequest.addHeader(name, headerValue); } @@ -181,6 +181,11 @@ private void addHeadersToRequest(HttpRequestBase httpRequest, SdkHttpRequest req } private String getHostHeaderValue(SdkHttpRequest request) { + // Respect any user-specified Host header when present + Optional existingHostHeader = request.firstMatchingHeader(HttpHeaders.HOST); + if (existingHostHeader.isPresent()) { + return existingHostHeader.get(); + } // Apache doesn't allow us to include the port in the host header if it's a standard port for that protocol. For that // reason, we don't include the port when we sign the message. See {@link SdkHttpRequest#port()}. return !SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port()) diff --git a/http-clients/apache-client/src/test/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactoryTest.java b/http-clients/apache-client/src/test/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactoryTest.java index 6699434a5351..7e5a13bf707f 100644 --- a/http-clients/apache-client/src/test/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactoryTest.java +++ b/http-clients/apache-client/src/test/java/software/amazon/awssdk/http/apache/internal/impl/ApacheHttpRequestFactoryTest.java @@ -71,6 +71,46 @@ public void createSetsHostHeaderByDefault() { assertEquals("localhost:12345", hostHeaders[0].getValue()); } + @Test + public void createRespectsUserHostHeader() { + String hostOverride = "virtual.host:123"; + SdkHttpRequest sdkRequest = SdkHttpRequest.builder() + .uri(URI.create("http://localhost:12345/")) + .method(SdkHttpMethod.HEAD) + .putHeader("Host", hostOverride) + .build(); + HttpExecuteRequest request = HttpExecuteRequest.builder() + .request(sdkRequest) + .build(); + + HttpRequestBase result = instance.create(request, requestConfig); + + Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST); + assertNotNull(hostHeaders); + assertEquals(1, hostHeaders.length); + assertEquals(hostOverride, hostHeaders[0].getValue()); + } + + @Test + public void createRespectsLowercaseUserHostHeader() { + String hostOverride = "virtual.host:123"; + SdkHttpRequest sdkRequest = SdkHttpRequest.builder() + .uri(URI.create("http://localhost:12345/")) + .method(SdkHttpMethod.HEAD) + .putHeader("host", hostOverride) + .build(); + HttpExecuteRequest request = HttpExecuteRequest.builder() + .request(sdkRequest) + .build(); + + HttpRequestBase result = instance.create(request, requestConfig); + + Header[] hostHeaders = result.getHeaders(HttpHeaders.HOST); + assertNotNull(hostHeaders); + assertEquals(1, hostHeaders.length); + assertEquals(hostOverride, hostHeaders[0].getValue()); + } + @Test public void putRequest_withTransferEncodingChunked_isChunkedAndDoesNotIncludeHeader() { SdkHttpRequest sdkRequest = SdkHttpRequest.builder() diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapter.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapter.java index 14e3ba8e99d2..c2c1a4a6848e 100644 --- a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapter.java +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapter.java @@ -24,6 +24,7 @@ import io.netty.handler.codec.http2.HttpConversionUtil.ExtensionHeaderNames; import java.util.Collections; import java.util.List; +import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.http.Protocol; import software.amazon.awssdk.http.SdkHttpMethod; @@ -87,13 +88,19 @@ private void addHeadersToRequest(DefaultHttpRequest httpRequest, SdkHttpRequest // Copy over any other headers already in our request request.forEachHeader((name, value) -> { // Skip the Host header to avoid sending it twice, which will interfere with some signing schemes. - if (!IGNORE_HEADERS.contains(name)) { + if (IGNORE_HEADERS.stream().noneMatch(name::equalsIgnoreCase)) { value.forEach(h -> httpRequest.headers().add(name, h)); } }); } private String getHostHeaderValue(SdkHttpRequest request) { + // Respect any user-specified Host header when present + Optional existingHostHeader = request.firstMatchingHeader(HOST); + if (existingHostHeader.isPresent()) { + return existingHostHeader.get(); + } + return SdkHttpUtils.isUsingStandardPort(request.protocol(), request.port()) ? request.host() : request.host() + ":" + request.port(); diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapterTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapterTest.java index 5b3f660c4196..273dcf5cba72 100644 --- a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapterTest.java +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/internal/RequestAdapterTest.java @@ -105,6 +105,34 @@ public void adapt_hostHeaderSet() { assertThat(hostHeaders).containsExactly("localhost:12345"); } + @Test + public void adapt_keepsUserHostHeader() { + String hostOverride = "virtual.host:123"; + SdkHttpRequest sdkRequest = SdkHttpRequest.builder() + .uri(URI.create("http://localhost:12345/")) + .method(SdkHttpMethod.HEAD) + .putHeader("Host", hostOverride) + .build(); + HttpRequest result = h1Adapter.adapt(sdkRequest); + List hostHeaders = result.headers() + .getAll(HttpHeaderNames.HOST.toString()); + assertThat(hostHeaders).containsExactly(hostOverride); + } + + @Test + public void adapt_keepsLowercaseUserHostHeader() { + String hostOverride = "virtual.host:123"; + SdkHttpRequest sdkRequest = SdkHttpRequest.builder() + .uri(URI.create("http://localhost:12345/")) + .method(SdkHttpMethod.HEAD) + .putHeader("host", hostOverride) + .build(); + HttpRequest result = h1Adapter.adapt(sdkRequest); + List hostHeaders = result.headers() + .getAll(HttpHeaderNames.HOST.toString()); + assertThat(hostHeaders).containsExactly(hostOverride); + } + @Test public void adapt_standardHttpsPort_omittedInHeader() { SdkHttpRequest sdkRequest = SdkHttpRequest.builder()