Skip to content

Commit

Permalink
IA-4893: don't materialize the entire response (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidangb authored May 14, 2024
1 parent ad57bda commit 76d982c
Show file tree
Hide file tree
Showing 4 changed files with 982 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.boot.actuate.health.HealthEndpoint;
import org.springframework.boot.actuate.health.Status;
import org.springframework.lang.NonNull;
import org.springframework.util.StreamUtils;

public class RelayedHttpRequestProcessor {

Expand Down Expand Up @@ -283,7 +284,7 @@ public Result writeTargetResponseOnCaller(@NonNull TargetHttpResponse targetResp
Result result = Result.SUCCESS;
if (targetResponse.getBody().isPresent()) {
try {
outputStream.write(targetResponse.getBody().get().readAllBytes());
StreamUtils.copy(targetResponse.getBody().get(), outputStream);
} catch (IOException e) {
logger.error("Failed to write response body to the remote client.", e);
result = Result.FAILURE;
Expand Down
4 changes: 4 additions & 0 deletions service/src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ spring:
cache-names: expiresAt
caffeine.spec: maximumSize=100,expireAfterWrite=90s

logging:
level:
com.microsoft.azure.relay.RelayLogger: WARN

listener:
# Connection string for the Azure Relay instance.
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasKey;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -19,6 +22,8 @@
import com.microsoft.azure.relay.RelayedHttpListenerResponse;
import com.microsoft.azure.relay.TrackingContext;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
Expand All @@ -30,13 +35,17 @@
import java.net.http.HttpHeaders;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.apache.commons.lang3.RandomStringUtils;
import org.broadinstitute.listener.config.CorsSupportProperties;
import org.broadinstitute.listener.relay.InvalidRelayTargetException;
import org.broadinstitute.listener.relay.http.RelayedHttpRequestProcessor.Result;
Expand All @@ -51,10 +60,12 @@
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.boot.actuate.health.HealthComponent;
import org.springframework.boot.actuate.health.HealthEndpoint;
import org.springframework.boot.actuate.health.Status;
import org.springframework.util.StreamUtils;

@ExtendWith(MockitoExtension.class)
class RelayedHttpRequestProcessorTest {
Expand All @@ -69,7 +80,6 @@ class RelayedHttpRequestProcessorTest {

@Mock private HttpClient httpClient;

@Mock private InputStream mockBody;
@Mock private RelayedHttpListenerContext context;
@Mock private RelayedHttpListenerRequest listenerRequest;
@Mock private RelayedHttpRequest request;
Expand Down Expand Up @@ -172,9 +182,14 @@ void writeTargetResponseOnCaller_responseIsWrittenBackToCaller() throws IOExcept

processor.writeTargetResponseOnCaller(targetHttpResponse);

verify(responseStream).write(responseData.capture());
// writeTargetResponseOnCaller relies on StreamUtils.copy, which specifies
// an offset and length in its call to write()
verify(responseStream).write(responseData.capture(), anyInt(), anyInt());

assertThat(new String(responseData.getValue()), equalTo(BODY_CONTENT));
// StreamUtils.copy buffers its writes in arrays of 4096 bytes (by default). Since this
// test uses a small response, we need to trim the empty part of the array before comparing
// the result
assertThat(new String(responseData.getValue()).trim(), equalTo(BODY_CONTENT));
}

@Test
Expand Down Expand Up @@ -241,16 +256,113 @@ void writeTargetResponseOnCaller_setsNoSniff() throws IOException {

@Test
void writeTargetResponseOnCaller_withBodyResponseStreamsClose() throws IOException {
// this test uses a spy to verify we close its stream
InputStream spyBody =
Mockito.spy(new ByteArrayInputStream(BODY_CONTENT.getBytes(StandardCharsets.UTF_8)));

when(targetHttpResponse.getContext()).thenReturn(context);
when(targetHttpResponse.getBody()).thenReturn(Optional.of(spyBody));
when(targetHttpResponse.getStatusCode()).thenReturn(200);
when(context.getResponse()).thenReturn(listenerResponse);
when(targetHttpResponse.getCallerResponseOutputStream()).thenReturn(responseStream);

processor.writeTargetResponseOnCaller(targetHttpResponse);

verify(responseStream).close();
verify(spyBody).close();
}

/**
* Given a large - multi-MB - http response, ensure the response is buffered back to the caller in
* multiple writes.
*
* @throws IOException on temp file error
*/
@Test
void writeTargetResponseOnCaller_withLargeBodyContent() throws IOException {
// write a bunch of junk data to a temp file
int numBufferChunks = 2000; // 2000 chunks * 4096 bytes/chunk = ~7.8MB
Path inputFile = Files.createTempFile("input-", ".tmp");
FileOutputStream fileOutputStream = new FileOutputStream(inputFile.toFile());
for (int i = 0; i < numBufferChunks; i++) {
fileOutputStream.write(
RandomStringUtils.randomAlphanumeric(StreamUtils.BUFFER_SIZE)
.getBytes(StandardCharsets.UTF_8));
}
fileOutputStream.close();

// create an InputStream from the temp file
FileInputStream fileInputStream = Mockito.spy(new FileInputStream(inputFile.toFile()));

// set the InputStream as the HTTP response
when(targetHttpResponse.getContext()).thenReturn(context);
when(targetHttpResponse.getBody()).thenReturn(Optional.of(mockBody));
when(targetHttpResponse.getBody()).thenReturn(Optional.of(fileInputStream));
when(targetHttpResponse.getStatusCode()).thenReturn(200);
when(context.getResponse()).thenReturn(listenerResponse);
when(targetHttpResponse.getCallerResponseOutputStream()).thenReturn(responseStream);

// write HTTP response to caller
processor.writeTargetResponseOnCaller(targetHttpResponse);

// verify that the HTTP response was buffered to the caller in chunks; we should have written
// to the caller ${numBufferChunks} times, because the temp file is of size
// ${numBufferChunks * StreamUtils.BUFFER_SIZE}
verify(responseStream, times(numBufferChunks)).write(any(), anyInt(), anyInt());

// verify everything was closed
verify(responseStream).close();
verify(mockBody).close();
verify(fileInputStream).close();

// clean up
Files.delete(inputFile);
}

/**
* Given a medium-sized (50KB) http response file, ensure that the response is buffered to the
* caller and that what the caller receives is equivalent to http response
*
* @throws IOException on temp file error
*/
@Test
void writeTargetResponseOnCaller_withMultipleChunks() throws IOException {
// create an InputStream from a sample text file
InputStream fileInputStream =
Mockito.spy(
Objects.requireNonNull(ClassLoader.getSystemResourceAsStream("sample-text.txt")));

// create a temp file to serve as our output stream
Path outputFile = Files.createTempFile("output-", ".tmp");
FileOutputStream responseOutputStream = Mockito.spy(new FileOutputStream(outputFile.toFile()));

// set the InputStream as the HTTP response
when(targetHttpResponse.getContext()).thenReturn(context);
when(targetHttpResponse.getBody()).thenReturn(Optional.of(fileInputStream));
when(targetHttpResponse.getStatusCode()).thenReturn(200);
when(context.getResponse()).thenReturn(listenerResponse);
when(targetHttpResponse.getCallerResponseOutputStream()).thenReturn(responseOutputStream);

// write HTTP response to caller
processor.writeTargetResponseOnCaller(targetHttpResponse);

// verify that the HTTP response was buffered to the caller in chunks. We should have written
// more than one chunk given the size of the input file.
verify(responseOutputStream, atLeast(1)).write(any(), anyInt(), anyInt());

// verify everything was closed
verify(responseOutputStream).close();
verify(fileInputStream).close();

// compare file contents
String expected =
new String(
Objects.requireNonNull(ClassLoader.getSystemResourceAsStream("sample-text.txt"))
.readAllBytes());
String actual = new String(Files.readAllBytes(outputFile));

assertThat("file contents differ", actual.equals(expected));

// clean up
Files.delete(outputFile);
}

@Test
Expand Down
Loading

0 comments on commit 76d982c

Please sign in to comment.