diff --git a/src/main/scala/alpakka/sse_to_elasticsearch/NerRequestOpenAI.java b/src/main/scala/alpakka/sse_to_elasticsearch/NerRequestOpenAI.java index 2a4eb7ec..73567257 100644 --- a/src/main/scala/alpakka/sse_to_elasticsearch/NerRequestOpenAI.java +++ b/src/main/scala/alpakka/sse_to_elasticsearch/NerRequestOpenAI.java @@ -1,12 +1,14 @@ package alpakka.sse_to_elasticsearch; -import org.apache.commons.io.IOUtils; import org.apache.hc.client5.http.classic.methods.HttpPost; -import org.apache.hc.client5.http.config.RequestConfig; +import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.client5.http.impl.DefaultHttpRequestRetryStrategy; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder; import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.io.entity.StringEntity; import org.apache.hc.core5.util.TimeValue; import org.apache.hc.core5.util.Timeout; @@ -15,7 +17,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; /** @@ -47,24 +48,30 @@ public String run(String text) { // For NER this means we get not just 'Person' but also 'Organisation', 'Location' requestParams.put("temperature", 0.2); - HttpPost request = new HttpPost("https://api.openai.com/v1/completions"); + String endpointURL = "https://api.openai.com/v1/completions"; + HttpPost request = new HttpPost(endpointURL); request.setHeader("Authorization", "Bearer " + API_KEY); StringEntity requestEntity = new StringEntity( requestParams.toString(), ContentType.APPLICATION_JSON); request.setEntity(requestEntity); - RequestConfig timeoutsConfig = RequestConfig.custom() - .setConnectTimeout(Timeout.of(DELAY_TO_RETRY_SECONDS, TimeUnit.SECONDS)).build(); + PoolingHttpClientConnectionManagerBuilder connectionManagerBuilder = PoolingHttpClientConnectionManagerBuilder.create(); + connectionManagerBuilder.setDefaultConnectionConfig(ConnectionConfig.custom() + .setSocketTimeout(Timeout.of(DELAY_TO_RETRY_SECONDS, TimeUnit.SECONDS)) + .build()); try (CloseableHttpClient httpClient = HttpClientBuilder.create() - .setDefaultRequestConfig(timeoutsConfig) + .setConnectionManager(connectionManagerBuilder.build()) .setRetryStrategy(new DefaultHttpRequestRetryStrategy(3, TimeValue.ofMinutes(1L))) .build()) { - return IOUtils.toString(httpClient.execute(request).getEntity().getContent(), StandardCharsets.UTF_8); + return httpClient.execute(request, response -> { + HttpEntity entity = response.getEntity(); + return entity != null ? EntityUtils.toString(entity) : "N/A"; + }); } catch (IOException e) { - LOGGER.warn("Unable to get result from openai completions endpoint. Cause: ", e); - return "N/A"; + LOGGER.warn("Connection issue while accessing openai API endpoint: {}. Cause: ", endpointURL, e); + throw new RuntimeException(e); } } } diff --git a/src/main/scala/sample/stream_shared_state/DownloaderRetry.java b/src/main/scala/sample/stream_shared_state/DownloaderRetry.java index 270cee31..c922d2d7 100644 --- a/src/main/scala/sample/stream_shared_state/DownloaderRetry.java +++ b/src/main/scala/sample/stream_shared_state/DownloaderRetry.java @@ -5,9 +5,10 @@ import org.apache.hc.client5.http.HttpRequestRetryStrategy; import org.apache.hc.client5.http.HttpResponseException; import org.apache.hc.client5.http.classic.methods.HttpGet; -import org.apache.hc.client5.http.config.RequestConfig; +import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder; import org.apache.hc.core5.http.ClassicHttpResponse; import org.apache.hc.core5.http.HttpRequest; import org.apache.hc.core5.http.HttpResponse; @@ -46,11 +47,13 @@ public static void main(String[] args) throws Exception { public Path download(int traceID, URI url, Path destinationFile) { LOGGER.info("TRACE_ID: {} about to download...", traceID); - RequestConfig timeoutsConfig = RequestConfig.custom() - .setConnectTimeout(Timeout.of(DELAY_TO_RETRY_SECONDS, TimeUnit.SECONDS)).build(); + PoolingHttpClientConnectionManagerBuilder connectionManagerBuilder = PoolingHttpClientConnectionManagerBuilder.create(); + connectionManagerBuilder.setDefaultConnectionConfig(ConnectionConfig.custom() + .setSocketTimeout(Timeout.of(DELAY_TO_RETRY_SECONDS, TimeUnit.SECONDS)) + .build()); try (CloseableHttpClient httpClient = HttpClientBuilder.create() - .setDefaultRequestConfig(timeoutsConfig) + .setConnectionManager(connectionManagerBuilder.build()) .setRetryStrategy(new CustomHttpRequestRetryStrategy()) .build()) { Path localPath = httpClient.execute(new HttpGet(url), new HttpResponseHandler(destinationFile)); diff --git a/src/main/scala/tools/OpenAICompletions.java b/src/main/scala/tools/OpenAICompletions.java index 34d749d7..ff11d500 100644 --- a/src/main/scala/tools/OpenAICompletions.java +++ b/src/main/scala/tools/OpenAICompletions.java @@ -1,12 +1,13 @@ package tools; -import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.hc.client5.http.classic.methods.HttpPost; import org.apache.hc.client5.http.config.RequestConfig; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.io.entity.StringEntity; import org.apache.hc.core5.util.Timeout; import org.json.JSONArray; @@ -15,7 +16,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.concurrent.TimeUnit; /** @@ -100,11 +100,14 @@ public String postRequest(JSONObject requestParams, String endpointURL) { .setDefaultRequestConfig(timeoutsConfig) .setRetryStrategy(new HttpRequestRetryStrategyOpenAI()) .build()) { - return IOUtils.toString(httpClient.execute(request).getEntity().getContent(), StandardCharsets.UTF_8); + return httpClient.execute(request, response -> { + HttpEntity entity = response.getEntity(); + return entity != null ? EntityUtils.toString(entity) : "N/A"; + }); } catch (IOException e) { - LOGGER.warn("Connection issue while accessing openai endpoint. Cause: ", e); + LOGGER.warn("Connection issue while accessing openai API endpoint: {}. Cause: ", endpointURL, e); + throw new RuntimeException(e); } - return "N/A"; } private static ImmutablePair extractPayloadChatCompletions(String jsonResponseChatCompletions) {