From d4ad1b09ca9a8f9bf6a143abaca4b298d31723e5 Mon Sep 17 00:00:00 2001 From: Taylor Gray Date: Wed, 21 Jun 2023 13:59:00 -0500 Subject: [PATCH] Elasticsearch client implementation with pit and no context search (#2910) Create Elasticsearch client, implement search and pit apis for ElasticsearchAccessor Signed-off-by: Taylor Gray --- .../opensearch-source/build.gradle | 1 + .../source/opensearch/OpenSearchSource.java | 5 +- .../worker/NoSearchContextWorker.java | 2 +- ...nSearchIndexPartitionCreationSupplier.java | 41 ++- .../source/opensearch/worker/PitWorker.java | 12 +- .../worker/client/ElasticsearchAccessor.java | 156 +++++++++- .../client/OpenSearchClientFactory.java | 293 ++++++++++++++++++ .../worker/client/SearchAccessorStrategy.java | 265 ++++------------ .../opensearch/OpenSearchSourceTest.java | 8 +- .../client/ElasticsearchAccessorTest.java | 281 +++++++++++++++++ .../client/OpenSearchClientFactoryTest.java | 122 ++++++++ ...rchIndexPartitionCreationSupplierTest.java | 97 ++++++ .../client/SearchAccessStrategyTest.java | 232 ++++++-------- 13 files changed, 1138 insertions(+), 377 deletions(-) create mode 100644 data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactory.java create mode 100644 data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessorTest.java create mode 100644 data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactoryTest.java diff --git a/data-prepper-plugins/opensearch-source/build.gradle b/data-prepper-plugins/opensearch-source/build.gradle index b974a9a32b..22a02beeda 100644 --- a/data-prepper-plugins/opensearch-source/build.gradle +++ b/data-prepper-plugins/opensearch-source/build.gradle @@ -14,6 +14,7 @@ dependencies { testImplementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' implementation 'org.opensearch.client:opensearch-java:2.5.0' implementation 'org.opensearch.client:opensearch-rest-client:2.7.0' + implementation 'co.elastic.clients:elasticsearch-java:7.17.0' implementation "org.apache.commons:commons-lang3:3.12.0" implementation('org.apache.maven:maven-artifact:3.0.3') { exclude group: 'org.codehaus.plexus' diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java index 779b8a3cbd..9a91db3ac1 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSource.java @@ -13,6 +13,7 @@ import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; import org.opensearch.dataprepper.model.source.coordinator.UsesSourceCoordination; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.OpenSearchClientFactory; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy; @@ -41,7 +42,9 @@ public void start(final Buffer> buffer) { } private void startProcess(final OpenSearchSourceConfiguration openSearchSourceConfiguration, final Buffer> buffer) { - final SearchAccessorStrategy searchAccessorStrategy = SearchAccessorStrategy.create(openSearchSourceConfiguration, awsCredentialsSupplier); + + final OpenSearchClientFactory openSearchClientFactory = OpenSearchClientFactory.create(awsCredentialsSupplier); + final SearchAccessorStrategy searchAccessorStrategy = SearchAccessorStrategy.create(openSearchSourceConfiguration, openSearchClientFactory); final SearchAccessor searchAccessor = searchAccessorStrategy.getSearchAccessor(); diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/NoSearchContextWorker.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/NoSearchContextWorker.java index 352d39588e..f63f99552c 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/NoSearchContextWorker.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/NoSearchContextWorker.java @@ -117,7 +117,7 @@ private void processIndex(final SourcePartition op } }); } catch (final Exception e) { - LOG.error("Received an exception while searching with PIT for index '{}'", indexName); + LOG.error("Received an exception while searching with no search context for index '{}'", indexName); throw new RuntimeException(e); } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/OpenSearchIndexPartitionCreationSupplier.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/OpenSearchIndexPartitionCreationSupplier.java index e4ca82a3bc..94f72efa33 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/OpenSearchIndexPartitionCreationSupplier.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/OpenSearchIndexPartitionCreationSupplier.java @@ -5,10 +5,11 @@ package org.opensearch.dataprepper.plugins.source.opensearch.worker; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.ElasticsearchException; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch._types.OpenSearchException; import org.opensearch.client.opensearch.cat.IndicesResponse; -import org.opensearch.client.opensearch.cat.indices.IndicesRecord; import org.opensearch.dataprepper.model.source.coordinator.PartitionIdentifier; import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; import org.opensearch.dataprepper.plugins.source.opensearch.configuration.IndexParametersConfiguration; @@ -32,7 +33,9 @@ public class OpenSearchIndexPartitionCreationSupplier implements Function apply(final Map globalStateMap) if (Objects.nonNull(openSearchClient)) { return applyForOpenSearchClient(globalStateMap); + } else if (Objects.nonNull(elasticsearchClient)) { + return applyForElasticSearchClient(globalStateMap); } return Collections.emptyList(); @@ -70,13 +77,29 @@ private List applyForOpenSearchClient(final Map shouldIndexBeProcessed(osIndicesRecord.index())) + .map(indexRecord -> PartitionIdentifier.builder().withPartitionKey(indexRecord.index()).build()) + .collect(Collectors.toList()); + } + + private List applyForElasticSearchClient(final Map globalStateMap) { + co.elastic.clients.elasticsearch.cat.IndicesResponse indicesResponse; + try { + indicesResponse = elasticsearchClient.cat().indices(); + } catch (IOException | ElasticsearchException e) { + LOG.error("There was an exception when calling /_cat/indices to create new index partitions", e); + return Collections.emptyList(); + } + + return indicesResponse.valueBody().stream() + .filter(esIndicesRecord -> shouldIndexBeProcessed(esIndicesRecord.index())) .map(indexRecord -> PartitionIdentifier.builder().withPartitionKey(indexRecord.index()).build()) .collect(Collectors.toList()); } - private boolean shouldIndexBeProcessed(final IndicesRecord indicesRecord) { - if (Objects.isNull(indicesRecord.index())) { + private boolean shouldIndexBeProcessed(final String indexName) { + + if (Objects.isNull(indexName)) { return false; } @@ -87,16 +110,16 @@ private boolean shouldIndexBeProcessed(final IndicesRecord indicesRecord) { final List includedIndices = indexParametersConfiguration.getIncludedIndices(); final List excludedIndices = indexParametersConfiguration.getExcludedIndices(); - final boolean matchesIncludedPattern = includedIndices.isEmpty() || doesIndexMatchPattern(includedIndices, indicesRecord); - final boolean matchesExcludePattern = doesIndexMatchPattern(excludedIndices, indicesRecord); + final boolean matchesIncludedPattern = includedIndices.isEmpty() || doesIndexMatchPattern(includedIndices, indexName); + final boolean matchesExcludePattern = doesIndexMatchPattern(excludedIndices, indexName); return matchesIncludedPattern && !matchesExcludePattern; } - private boolean doesIndexMatchPattern(final List indices, final IndicesRecord indicesRecord) { + private boolean doesIndexMatchPattern(final List indices, final String indexName) { for (final OpenSearchIndex index : indices) { - final Matcher matcher = index.getIndexNamePattern().matcher(indicesRecord.index()); + final Matcher matcher = index.getIndexNamePattern().matcher(indexName); if (matcher.matches()) { return true; diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java index f4a2510322..73040240ed 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/PitWorker.java @@ -179,12 +179,12 @@ private OpenSearchIndexProgressState initializeProgressState() { } private List getSearchAfter(final OpenSearchIndexProgressState openSearchIndexProgressState, final SearchWithSearchAfterResults searchWithSearchAfterResults) { - if (Objects.isNull(searchWithSearchAfterResults) && Objects.isNull(openSearchIndexProgressState.getSearchAfter())) { - return null; - } - - if (Objects.isNull(searchWithSearchAfterResults) && Objects.nonNull(openSearchIndexProgressState.getSearchAfter())) { - return openSearchIndexProgressState.getSearchAfter(); + if (Objects.isNull(searchWithSearchAfterResults)) { + if (Objects.isNull(openSearchIndexProgressState.getSearchAfter())) { + return null; + } else { + return openSearchIndexProgressState.getSearchAfter(); + } } return searchWithSearchAfterResults.getNextSearchAfter(); diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java index 3e3b8b6794..97185081d9 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessor.java @@ -4,6 +4,26 @@ */ package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.ElasticsearchException; +import co.elastic.clients.elasticsearch._types.ScoreSort; +import co.elastic.clients.elasticsearch._types.SortOptions; +import co.elastic.clients.elasticsearch._types.SortOrder; +import co.elastic.clients.elasticsearch._types.Time; +import co.elastic.clients.elasticsearch._types.query_dsl.MatchAllQuery; +import co.elastic.clients.elasticsearch._types.query_dsl.Query; +import co.elastic.clients.elasticsearch.core.ClosePointInTimeRequest; +import co.elastic.clients.elasticsearch.core.ClosePointInTimeResponse; +import co.elastic.clients.elasticsearch.core.OpenPointInTimeRequest; +import co.elastic.clients.elasticsearch.core.OpenPointInTimeResponse; +import co.elastic.clients.elasticsearch.core.SearchRequest; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.search.PointInTimeReference; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventType; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.exceptions.SearchContextLimitException; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreatePointInTimeRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreatePointInTimeResponse; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreateScrollRequest; @@ -13,32 +33,100 @@ import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.NoSearchContextSearchRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPointInTimeRequest; -import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchWithSearchAfterResults; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchScrollRequest; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchScrollResponse; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchWithSearchAfterResults; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.MetadataKeyAttributes.DOCUMENT_ID_METADATA_ATTRIBUTE_NAME; +import static org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.MetadataKeyAttributes.INDEX_METADATA_ATTRIBUTE_NAME; public class ElasticsearchAccessor implements SearchAccessor, ClusterClientFactory { + + private static final Logger LOG = LoggerFactory.getLogger(ElasticsearchAccessor.class); + + static final String PIT_RESOURCE_LIMIT_ERROR_TYPE = "rejected_execution_exception"; + + private final ElasticsearchClient elasticsearchClient; + private final SearchContextType searchContextType; + + public ElasticsearchAccessor(final ElasticsearchClient elasticsearchClient, final SearchContextType searchContextType) { + this.elasticsearchClient = elasticsearchClient; + this.searchContextType = searchContextType; + } + @Override public SearchContextType getSearchContextType() { - // todo: implement - return null; + return searchContextType; } @Override public CreatePointInTimeResponse createPit(final CreatePointInTimeRequest createPointInTimeRequest) { - //todo: implement - return null; + + OpenPointInTimeResponse openPointInTimeResponse; + try { + openPointInTimeResponse = elasticsearchClient.openPointInTime(OpenPointInTimeRequest.of(request -> request + .keepAlive(Time.of(time -> time.time(createPointInTimeRequest.getKeepAlive()))) + .index(createPointInTimeRequest.getIndex()))); + } catch (final ElasticsearchException e) { + if (isDueToPitLimitExceeded(e)) { + throw new SearchContextLimitException(String.format("There was an error creating a new point in time for index '%s': %s", createPointInTimeRequest.getIndex(), + e.error().causedBy().reason())); + } + LOG.error("There was an error creating a point in time for Elasticsearch: ", e); + throw e; + } catch (final IOException e) { + LOG.error("There was an error creating a point in time for Elasticsearch: ", e); + throw new RuntimeException(e); + } + + return CreatePointInTimeResponse.builder() + .withPitId(openPointInTimeResponse.id()) + .withCreationTime(Instant.now().toEpochMilli()) + .build(); } @Override - public SearchWithSearchAfterResults searchWithPit(SearchPointInTimeRequest searchPointInTimeRequest) { - //todo: implement - return null; + public SearchWithSearchAfterResults searchWithPit(final SearchPointInTimeRequest searchPointInTimeRequest) { + final SearchRequest searchRequest = SearchRequest.of(builder -> { builder + .pit(PointInTimeReference.of(pit -> pit + .id(searchPointInTimeRequest.getPitId()) + .keepAlive(Time.of(time -> time.time(searchPointInTimeRequest.getKeepAlive()))))) + .size(searchPointInTimeRequest.getPaginationSize()) + .sort(SortOptions.of(sortOptionsBuilder -> sortOptionsBuilder.doc(ScoreSort.of(scoreSort -> scoreSort.order(SortOrder.Asc))))) + .query(Query.of(query -> query.matchAll(MatchAllQuery.of(matchAllQuery -> matchAllQuery)))); + + if (Objects.nonNull(searchPointInTimeRequest.getSearchAfter())) { + builder.searchAfter(searchPointInTimeRequest.getSearchAfter()); + } + return builder; + }); + + + return searchWithSearchAfter(searchRequest); } @Override public void deletePit(final DeletePointInTimeRequest deletePointInTimeRequest) { - //todo: implement + try { + final ClosePointInTimeResponse closePointInTimeResponse = elasticsearchClient.closePointInTime(ClosePointInTimeRequest.of(request -> request + .id(deletePointInTimeRequest.getPitId()))); + if (closePointInTimeResponse.succeeded()) { + LOG.debug("Successfully deleted point in time id {}", deletePointInTimeRequest.getPitId()); + } else { + LOG.warn("Point in time id {} was not deleted successfully. It will expire from keep-alive", deletePointInTimeRequest.getPitId()); + } + } catch (final IOException | RuntimeException e) { + LOG.error("There was an error deleting the point in time with id {} for Elasticsearch. It will expire from keep-alive: ", deletePointInTimeRequest.getPitId(), e); + } } @Override @@ -59,12 +147,56 @@ public void deleteScroll(DeleteScrollRequest deleteScrollRequest) { } @Override - public SearchWithSearchAfterResults searchWithoutSearchContext(NoSearchContextSearchRequest noSearchContextSearchRequest) { - return null; + public SearchWithSearchAfterResults searchWithoutSearchContext(final NoSearchContextSearchRequest noSearchContextSearchRequest) { + final SearchRequest searchRequest = SearchRequest.of(builder -> { + builder + .index(noSearchContextSearchRequest.getIndex()) + .size(noSearchContextSearchRequest.getPaginationSize()) + .sort(SortOptions.of(sortOptionsBuilder -> sortOptionsBuilder.doc(ScoreSort.of(scoreSort -> scoreSort.order(SortOrder.Asc))))) + .query(Query.of(query -> query.matchAll(MatchAllQuery.of(matchAllQuery -> matchAllQuery)))); + + if (Objects.nonNull(noSearchContextSearchRequest.getSearchAfter())) { + builder.searchAfter(noSearchContextSearchRequest.getSearchAfter()); + } + + return builder; + }); + + return searchWithSearchAfter(searchRequest); } @Override public Object getClient() { - return null; + return elasticsearchClient; + } + + private SearchWithSearchAfterResults searchWithSearchAfter(final SearchRequest searchRequest) { + + try { + final SearchResponse searchResponse = elasticsearchClient.search(searchRequest, ObjectNode.class); + + final List documents = searchResponse.hits().hits().stream() + .map(hit -> JacksonEvent.builder() + .withData(hit.source()) + .withEventMetadataAttributes(Map.of(DOCUMENT_ID_METADATA_ATTRIBUTE_NAME, hit.id(), INDEX_METADATA_ATTRIBUTE_NAME, hit.index())) + .withEventType(EventType.DOCUMENT.toString()).build()) + .collect(Collectors.toList()); + + final List nextSearchAfter = Objects.nonNull(searchResponse.hits().hits()) && !searchResponse.hits().hits().isEmpty() ? + searchResponse.hits().hits().get(searchResponse.hits().hits().size() - 1).sort() : + null; + + return SearchWithSearchAfterResults.builder() + .withDocuments(documents) + .withNextSearchAfter(nextSearchAfter) + .build(); + } catch (final IOException e) { + throw new RuntimeException(e); + } + } + + private boolean isDueToPitLimitExceeded(final ElasticsearchException e) { + return Objects.nonNull(e.error()) && Objects.nonNull(e.error().causedBy()) && Objects.nonNull(e.error().causedBy().type()) + && PIT_RESOURCE_LIMIT_ERROR_TYPE.equals(e.error().causedBy().type()); } } diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactory.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactory.java new file mode 100644 index 0000000000..d9e3a2f739 --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactory.java @@ -0,0 +1,293 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.transport.ElasticsearchTransport; +import org.apache.http.Header; +import org.apache.http.HttpHost; +import org.apache.http.HttpResponseInterceptor; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.conn.ssl.TrustAllStrategy; +import org.apache.http.conn.ssl.TrustStrategy; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; +import org.apache.http.message.BasicHeader; +import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.ssl.SSLContexts; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.client.json.jackson.JacksonJsonpMapper; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.transport.OpenSearchTransport; +import org.opensearch.client.transport.aws.AwsSdk2Transport; +import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; +import org.opensearch.client.transport.rest_client.RestClientTransport; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyStore; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.util.List; +import java.util.Objects; + +public class OpenSearchClientFactory { + + private static final Logger LOG = LoggerFactory.getLogger(OpenSearchClientFactory.class); + + private static final String AOS_SERVICE_NAME = "es"; + + private final AwsCredentialsSupplier awsCredentialsSupplier; + + public static OpenSearchClientFactory create(final AwsCredentialsSupplier awsCredentialsSupplier) { + return new OpenSearchClientFactory(awsCredentialsSupplier); + } + + private OpenSearchClientFactory(final AwsCredentialsSupplier awsCredentialsSupplier) { + this.awsCredentialsSupplier = awsCredentialsSupplier; + } + + public OpenSearchClient provideOpenSearchClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + OpenSearchTransport transport; + if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) { + transport = createOpenSearchTransportForAws(openSearchSourceConfiguration); + } else { + final RestClient restClient = createOpenSearchRestClient(openSearchSourceConfiguration); + transport = createOpenSearchTransport(restClient); + } + return new OpenSearchClient(transport); + } + + public ElasticsearchClient provideElasticSearchClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + final org.elasticsearch.client.RestClient restClientElasticsearch = createElasticSearchRestClient(openSearchSourceConfiguration); + final ElasticsearchTransport elasticsearchTransport = createElasticSearchTransport(restClientElasticsearch); + return new ElasticsearchClient(elasticsearchTransport); + } + + private OpenSearchTransport createOpenSearchTransport(final RestClient restClient) { + return new RestClientTransport(restClient, new JacksonJsonpMapper()); + } + + private OpenSearchTransport createOpenSearchTransportForAws(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder() + .withRegion(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion()) + .withStsRoleArn(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsRoleArn()) + .withStsExternalId(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsExternalId()) + .withStsHeaderOverrides(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsHeaderOverrides()) + .build()); + + return new AwsSdk2Transport(createSdkHttpClient(openSearchSourceConfiguration), + HttpHost.create(openSearchSourceConfiguration.getHosts().get(0)).getHostName(), + AOS_SERVICE_NAME, openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(), + AwsSdk2TransportOptions.builder() + .setCredentials(awsCredentialsProvider) + .setMapper(new JacksonJsonpMapper()) + .build()); + } + + private SdkHttpClient createSdkHttpClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + final ApacheHttpClient.Builder apacheHttpClientBuilder = ApacheHttpClient.builder(); + + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) { + apacheHttpClientBuilder.connectionTimeout(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout()); + } + + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout())) { + apacheHttpClientBuilder.socketTimeout(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout()); + } + + attachSSLContext(apacheHttpClientBuilder, openSearchSourceConfiguration); + + return apacheHttpClientBuilder.build(); + } + + private RestClient createOpenSearchRestClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + final List hosts = openSearchSourceConfiguration.getHosts(); + final HttpHost[] httpHosts = new HttpHost[hosts.size()]; + + int i = 0; + for (final String host : hosts) { + httpHosts[i] = HttpHost.create(host); + i++; + } + + final RestClientBuilder restClientBuilder = RestClient.builder(httpHosts); + + LOG.info("Using username and password for auth for the OpenSearch source"); + attachUsernamePassword(restClientBuilder, openSearchSourceConfiguration); + + setConnectAndSocketTimeout(restClientBuilder, openSearchSourceConfiguration); + + return restClientBuilder.build(); + } + + private ElasticsearchTransport createElasticSearchTransport(final org.elasticsearch.client.RestClient restClient) { + return new co.elastic.clients.transport.rest_client.RestClientTransport(restClient, new co.elastic.clients.json.jackson.JacksonJsonpMapper()); + } + + private org.elasticsearch.client.RestClient createElasticSearchRestClient(final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + final List hosts = openSearchSourceConfiguration.getHosts(); + final HttpHost[] httpHosts = new HttpHost[hosts.size()]; + + int i = 0; + for (final String host : hosts) { + httpHosts[i] = HttpHost.create(host); + i++; + } + + final org.elasticsearch.client.RestClientBuilder restClientBuilder = org.elasticsearch.client.RestClient.builder(httpHosts); + + restClientBuilder.setDefaultHeaders(new Header[] { + new BasicHeader("Content-type", "application/json") + }); + + LOG.info("Using username and password for auth for the OpenSearch source"); + attachUsernamePassword(restClientBuilder, openSearchSourceConfiguration); + + setConnectAndSocketTimeout(restClientBuilder, openSearchSourceConfiguration); + + return restClientBuilder.build(); + } + + private void attachUsernamePassword(final RestClientBuilder restClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(AuthScope.ANY, + new UsernamePasswordCredentials(openSearchSourceConfiguration.getUsername(), openSearchSourceConfiguration.getPassword())); + + restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> { + httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + attachSSLContext(httpClientBuilder, openSearchSourceConfiguration); + return httpClientBuilder; + }); + } + + private void attachUsernamePassword(final org.elasticsearch.client.RestClientBuilder restClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(AuthScope.ANY, + new UsernamePasswordCredentials(openSearchSourceConfiguration.getUsername(), openSearchSourceConfiguration.getPassword())); + + restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> { + httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); + attachSSLContext(httpClientBuilder, openSearchSourceConfiguration); + httpClientBuilder.addInterceptorLast( + (HttpResponseInterceptor) + (response, context) -> + response.addHeader("X-Elastic-Product", "Elasticsearch")); + return httpClientBuilder; + }); + } + + private void setConnectAndSocketTimeout(final RestClientBuilder restClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + restClientBuilder.setRequestConfigCallback(requestConfigBuilder -> { + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) { + requestConfigBuilder.setConnectTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout().toMillis()); + } + + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout())) { + requestConfigBuilder.setSocketTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout().toMillis()); + } + + return requestConfigBuilder; + }); + } + + private void setConnectAndSocketTimeout(final org.elasticsearch.client.RestClientBuilder restClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + restClientBuilder.setRequestConfigCallback(requestConfigBuilder -> { + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) { + requestConfigBuilder.setConnectTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout().toMillis()); + } + + if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout())) { + requestConfigBuilder.setSocketTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout().toMillis()); + } + + return requestConfigBuilder; + }); + } + + private void attachSSLContext(final ApacheHttpClient.Builder apacheHttpClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + TrustManager[] trustManagers = createTrustManagers(openSearchSourceConfiguration.getConnectionConfiguration().getCertPath()); + apacheHttpClientBuilder.tlsTrustManagersProvider(() -> trustManagers); + } + + private void attachSSLContext(final HttpAsyncClientBuilder httpClientBuilder, final OpenSearchSourceConfiguration openSearchSourceConfiguration) { + + final ConnectionConfiguration connectionConfiguration = openSearchSourceConfiguration.getConnectionConfiguration(); + final SSLContext sslContext = Objects.nonNull(connectionConfiguration.getCertPath()) ? getCAStrategy(connectionConfiguration.getCertPath()) : getTrustAllStrategy(); + httpClientBuilder.setSSLContext(sslContext); + + if (connectionConfiguration.isInsecure()) { + httpClientBuilder.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); + } + } + + private static TrustManager[] createTrustManagers(final Path certPath) { + if (certPath != null) { + LOG.info("Using the cert provided in the config."); + try (InputStream certificateInputStream = Files.newInputStream(certPath)) { + final CertificateFactory factory = CertificateFactory.getInstance("X.509"); + final Certificate trustedCa = factory.generateCertificate(certificateInputStream); + final KeyStore trustStore = KeyStore.getInstance("pkcs12"); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", trustedCa); + + final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance("X509"); + trustManagerFactory.init(trustStore); + return trustManagerFactory.getTrustManagers(); + } catch (Exception ex) { + throw new RuntimeException(ex.getMessage(), ex); + } + } else { + return new TrustManager[] { new X509TrustAllManager() }; + } + } + + private SSLContext getCAStrategy(final Path certPath) { + LOG.info("Using the cert provided in the config."); + try { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + Certificate trustedCa; + try (InputStream is = Files.newInputStream(certPath)) { + trustedCa = factory.generateCertificate(is); + } + KeyStore trustStore = KeyStore.getInstance("pkcs12"); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", trustedCa); + SSLContextBuilder sslContextBuilder = SSLContexts.custom() + .loadTrustMaterial(trustStore, null); + return sslContextBuilder.build(); + } catch (Exception ex) { + throw new RuntimeException(ex.getMessage(), ex); + } + } + + private SSLContext getTrustAllStrategy() { + LOG.info("Using the trust all strategy"); + final TrustStrategy trustStrategy = new TrustAllStrategy(); + try { + return SSLContexts.custom().loadTrustMaterial(null, trustStrategy).build(); + } catch (Exception ex) { + throw new RuntimeException(ex.getMessage(), ex); + } + } +} diff --git a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java index f25c59b8f2..50e7414703 100644 --- a/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java +++ b/data-prepper-plugins/opensearch-source/src/main/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessorStrategy.java @@ -4,50 +4,19 @@ */ package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; -import org.apache.http.HttpHost; -import org.apache.http.auth.AuthScope; -import org.apache.http.auth.UsernamePasswordCredentials; -import org.apache.http.client.CredentialsProvider; -import org.apache.http.conn.ssl.NoopHostnameVerifier; -import org.apache.http.conn.ssl.TrustAllStrategy; -import org.apache.http.conn.ssl.TrustStrategy; -import org.apache.http.impl.client.BasicCredentialsProvider; -import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; -import org.apache.http.ssl.SSLContextBuilder; -import org.apache.http.ssl.SSLContexts; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import org.apache.commons.lang3.tuple.Pair; import org.apache.maven.artifact.versioning.DefaultArtifactVersion; -import org.opensearch.client.RestClient; -import org.opensearch.client.RestClientBuilder; -import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch._types.OpenSearchException; import org.opensearch.client.opensearch.core.InfoResponse; -import org.opensearch.client.transport.OpenSearchTransport; -import org.opensearch.client.transport.aws.AwsSdk2Transport; -import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; -import org.opensearch.client.transport.rest_client.RestClientTransport; -import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.client.util.MissingRequiredPropertyException; import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; -import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.security.KeyStore; -import java.security.cert.Certificate; -import java.security.cert.CertificateFactory; -import java.util.List; import java.util.Objects; /** @@ -59,22 +28,27 @@ public class SearchAccessorStrategy { private static final Logger LOG = LoggerFactory.getLogger(SearchAccessorStrategy.class); - private static final String AOS_SERVICE_NAME = "es"; static final String OPENSEARCH_DISTRIBUTION = "opensearch"; + static final String ELASTICSEARCH_DISTRIBUTION = "elasticsearch"; + static final String ELASTICSEARCH_OSS_BUILD_FLAVOR = "oss"; + static final String OPENDISTRO_DISTRIUBTION = "opendistro"; + private static final String OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF = "2.5.0"; + private static final String ELASTICSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF = "7.10.0"; + - private final AwsCredentialsSupplier awsCredentialsSupplier; + private final OpenSearchClientFactory openSearchClientFactory; private final OpenSearchSourceConfiguration openSearchSourceConfiguration; public static SearchAccessorStrategy create(final OpenSearchSourceConfiguration openSearchSourceConfiguration, - final AwsCredentialsSupplier awsCredentialsSupplier) { - return new SearchAccessorStrategy(openSearchSourceConfiguration, awsCredentialsSupplier); + final OpenSearchClientFactory openSearchClientFactory) { + return new SearchAccessorStrategy(openSearchSourceConfiguration, openSearchClientFactory); } private SearchAccessorStrategy(final OpenSearchSourceConfiguration openSearchSourceConfiguration, - final AwsCredentialsSupplier awsCredentialsSupplier) { - this.awsCredentialsSupplier = awsCredentialsSupplier; + final OpenSearchClientFactory openSearchClientFactory) { this.openSearchSourceConfiguration = openSearchSourceConfiguration; + this.openSearchClientFactory = openSearchClientFactory; } /** @@ -84,208 +58,93 @@ private SearchAccessorStrategy(final OpenSearchSourceConfiguration openSearchSou */ public SearchAccessor getSearchAccessor() { - OpenSearchTransport transport; - if (Objects.nonNull(openSearchSourceConfiguration.getAwsAuthenticationOptions())) { - transport = createOpenSearchTransportForAws(); - } else { - final RestClient restClient = createOpenSearchRestClient(); - transport = createOpenSearchTransport(restClient); - } - final OpenSearchClient openSearchClient = new OpenSearchClient(transport); + final OpenSearchClient openSearchClient = openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration); - InfoResponse infoResponse; + InfoResponse infoResponse = null; + + ElasticsearchClient elasticsearchClient = null; try { infoResponse = openSearchClient.info(); + } catch (final MissingRequiredPropertyException e) { + LOG.info("Detected Elasticsearch cluster. Constructing Elasticsearch client"); + elasticsearchClient = openSearchClientFactory.provideElasticSearchClient(openSearchSourceConfiguration); } catch (final IOException | OpenSearchException e) { throw new RuntimeException("There was an error looking up the OpenSearch cluster info: ", e); } - final String distribution = infoResponse.version().distribution(); - final String versionNumber = infoResponse.version().number(); + final Pair distributionAndVersion = getDistributionAndVersionNumber(infoResponse, elasticsearchClient); - if (!distribution.equals(OPENSEARCH_DISTRIBUTION)) { - throw new IllegalArgumentException(String.format("Only opensearch distributions are supported at this time. The cluster distribution being used is '%s'", distribution)); - } + final String distribution = distributionAndVersion.getLeft(); + final String versionNumber = distributionAndVersion.getRight(); + + validateDistribution(distribution); SearchContextType searchContextType; if (Objects.nonNull(openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType())) { LOG.info("Using search_context_type set in the config: '{}'", openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType().toString().toLowerCase()); - validateSearchContextTypeOverride(openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType(), versionNumber); + validateSearchContextTypeOverride(openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType(), distribution, versionNumber); searchContextType = openSearchSourceConfiguration.getSearchConfiguration().getSearchContextType(); - } else if (versionSupportsPointInTimeForOpenSearch(versionNumber)) { - LOG.info("OpenSearch version {} detected. Point in time APIs will be used to search documents", versionNumber); + } else if (versionSupportsPointInTime(distribution, versionNumber)) { + LOG.info("{} distribution and version {} detected. Point in time APIs will be used to search documents", distribution, versionNumber); searchContextType = SearchContextType.POINT_IN_TIME; } else { - LOG.info("OpenSearch version {} detected. Scroll contexts will be used to search documents. " + - "Upgrade your cluster to at least version {} to use Point in Time APIs instead of scroll.", versionNumber, OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF); + LOG.info("{} distribution, version {} detected. Scroll contexts will be used to search documents. " + + "Upgrade your cluster to at least version {} to use Point in Time APIs instead of scroll.", distribution, versionNumber, + distribution.equals(ELASTICSEARCH_DISTRIBUTION) ? ELASTICSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF : OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF); searchContextType = SearchContextType.SCROLL; } - return new OpenSearchAccessor(openSearchClient, searchContextType); - } - - private RestClient createOpenSearchRestClient() { - final List hosts = openSearchSourceConfiguration.getHosts(); - final HttpHost[] httpHosts = new HttpHost[hosts.size()]; - - int i = 0; - for (final String host : hosts) { - httpHosts[i] = HttpHost.create(host); - i++; + if (Objects.isNull(elasticsearchClient)) { + return new OpenSearchAccessor(openSearchClient, searchContextType); } - final RestClientBuilder restClientBuilder = RestClient.builder(httpHosts); - - LOG.info("Using username and password for auth for the OpenSearch source"); - attachUsernamePassword(restClientBuilder); - - setConnectAndSocketTimeout(restClientBuilder); - - return restClientBuilder.build(); + return new ElasticsearchAccessor(elasticsearchClient, searchContextType); } - private void attachSSLContext(final ApacheHttpClient.Builder apacheHttpClientBuilder) { - TrustManager[] trustManagers = createTrustManagers(openSearchSourceConfiguration.getConnectionConfiguration().getCertPath()); - apacheHttpClientBuilder.tlsTrustManagersProvider(() -> trustManagers); - } - - private void attachSSLContext(final HttpAsyncClientBuilder httpClientBuilder) { + private void validateSearchContextTypeOverride(final SearchContextType searchContextType, final String distribution, final String version) { - final ConnectionConfiguration connectionConfiguration = openSearchSourceConfiguration.getConnectionConfiguration(); - final SSLContext sslContext = Objects.nonNull(connectionConfiguration.getCertPath()) ? getCAStrategy(connectionConfiguration.getCertPath()) : getTrustAllStrategy(); - httpClientBuilder.setSSLContext(sslContext); - - if (connectionConfiguration.isInsecure()) { - httpClientBuilder.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); + if (searchContextType.equals(SearchContextType.POINT_IN_TIME) && !versionSupportsPointInTime(distribution, version)) { + throw new IllegalArgumentException( + String.format("A search_context_type of point_in_time is only supported on OpenSearch versions %s and above. " + + "The version of the OpenSearch cluster passed is %s. Elasticsearch clusters with build-flavor %s do not support point in time", + distribution.startsWith(ELASTICSEARCH_DISTRIBUTION) ? ELASTICSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF : OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF, + version, ELASTICSEARCH_OSS_BUILD_FLAVOR)); } } - private static TrustManager[] createTrustManagers(final Path certPath) { - if (certPath != null) { - LOG.info("Using the cert provided in the config."); - try (InputStream certificateInputStream = Files.newInputStream(certPath)) { - final CertificateFactory factory = CertificateFactory.getInstance("X.509"); - final Certificate trustedCa = factory.generateCertificate(certificateInputStream); - final KeyStore trustStore = KeyStore.getInstance("pkcs12"); - trustStore.load(null, null); - trustStore.setCertificateEntry("ca", trustedCa); + private boolean versionSupportsPointInTime(final String distribution, final String version) { + final DefaultArtifactVersion actualVersion = new DefaultArtifactVersion(version); - final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance("X509"); - trustManagerFactory.init(trustStore); - return trustManagerFactory.getTrustManagers(); - } catch (Exception ex) { - throw new RuntimeException(ex.getMessage(), ex); + DefaultArtifactVersion cutoffVersion; + if (distribution.startsWith(ELASTICSEARCH_DISTRIBUTION)) { + if (distribution.endsWith(ELASTICSEARCH_OSS_BUILD_FLAVOR)) { + return false; } + cutoffVersion = new DefaultArtifactVersion(ELASTICSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF); } else { - return new TrustManager[] { new X509TrustAllManager() }; - } - } - - private void attachUsernamePassword(final RestClientBuilder restClientBuilder) { - final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider.setCredentials(AuthScope.ANY, - new UsernamePasswordCredentials(openSearchSourceConfiguration.getUsername(), openSearchSourceConfiguration.getPassword())); - - restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> { - httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider); - attachSSLContext(httpClientBuilder); - return httpClientBuilder; - }); - } - - private void setConnectAndSocketTimeout(final RestClientBuilder restClientBuilder) { - restClientBuilder.setRequestConfigCallback(requestConfigBuilder -> { - if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) { - requestConfigBuilder.setConnectTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout().toMillis()); - } - - if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout())) { - requestConfigBuilder.setSocketTimeout((int) openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout().toMillis()); - } - - return requestConfigBuilder; - }); - } - - private OpenSearchTransport createOpenSearchTransport(final RestClient restClient) { - return new RestClientTransport(restClient, new JacksonJsonpMapper()); - } - - private OpenSearchTransport createOpenSearchTransportForAws() { - final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder() - .withRegion(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion()) - .withStsRoleArn(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsRoleArn()) - .withStsExternalId(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsExternalId()) - .withStsHeaderOverrides(openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsStsHeaderOverrides()) - .build()); - - return new AwsSdk2Transport(createSdkHttpClient(), - HttpHost.create(openSearchSourceConfiguration.getHosts().get(0)).getHostName(), - AOS_SERVICE_NAME, openSearchSourceConfiguration.getAwsAuthenticationOptions().getAwsRegion(), - AwsSdk2TransportOptions.builder() - .setCredentials(awsCredentialsProvider) - .setMapper(new JacksonJsonpMapper()) - .build()); - } - - private SdkHttpClient createSdkHttpClient() { - final ApacheHttpClient.Builder apacheHttpClientBuilder = ApacheHttpClient.builder(); - - if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout())) { - apacheHttpClientBuilder.connectionTimeout(openSearchSourceConfiguration.getConnectionConfiguration().getConnectTimeout()); - } - - if (Objects.nonNull(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout())) { - apacheHttpClientBuilder.socketTimeout(openSearchSourceConfiguration.getConnectionConfiguration().getSocketTimeout()); + cutoffVersion = new DefaultArtifactVersion(OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF); } - - attachSSLContext(apacheHttpClientBuilder); - - return apacheHttpClientBuilder.build(); + return actualVersion.compareTo(cutoffVersion) >= 0; } - private SSLContext getCAStrategy(final Path certPath) { - LOG.info("Using the cert provided in the config."); - try { - CertificateFactory factory = CertificateFactory.getInstance("X.509"); - Certificate trustedCa; - try (InputStream is = Files.newInputStream(certPath)) { - trustedCa = factory.generateCertificate(is); - } - KeyStore trustStore = KeyStore.getInstance("pkcs12"); - trustStore.load(null, null); - trustStore.setCertificateEntry("ca", trustedCa); - SSLContextBuilder sslContextBuilder = SSLContexts.custom() - .loadTrustMaterial(trustStore, null); - return sslContextBuilder.build(); - } catch (Exception ex) { - throw new RuntimeException(ex.getMessage(), ex); + private Pair getDistributionAndVersionNumber(final InfoResponse infoResponseOpenSearch, final ElasticsearchClient elasticsearchClient) { + if (Objects.nonNull(infoResponseOpenSearch)) { + return Pair.of(infoResponseOpenSearch.version().distribution(), infoResponseOpenSearch.version().number()); } - } - private SSLContext getTrustAllStrategy() { - LOG.info("Using the trust all strategy"); - final TrustStrategy trustStrategy = new TrustAllStrategy(); try { - return SSLContexts.custom().loadTrustMaterial(null, trustStrategy).build(); - } catch (Exception ex) { - throw new RuntimeException(ex.getMessage(), ex); + final co.elastic.clients.elasticsearch.core.InfoResponse infoResponseElasticsearch = elasticsearchClient.info(); + return Pair.of(ELASTICSEARCH_DISTRIBUTION + "-" + infoResponseElasticsearch.version().buildFlavor(), infoResponseElasticsearch.version().number()); + } catch (final Exception e) { + throw new RuntimeException("Unable to call info API using the elasticsearch client", e); } } - private void validateSearchContextTypeOverride(final SearchContextType searchContextType, final String version) { - - if (searchContextType.equals(SearchContextType.POINT_IN_TIME) && !versionSupportsPointInTimeForOpenSearch(version)) { - throw new IllegalArgumentException( - String.format("A search_context_type of point_in_time is only supported on OpenSearch versions %s and above. " + - "The version of the OpenSearch cluster passed is %s", OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF, version)); + private void validateDistribution(final String distribution) { + if (!distribution.equals(OPENSEARCH_DISTRIBUTION) && !distribution.startsWith(ELASTICSEARCH_DISTRIBUTION) && !distribution.equals(OPENDISTRO_DISTRIUBTION)) { + throw new IllegalArgumentException(String.format("Only %s, %s, or %s distributions are supported at this time. The cluster distribution being used is '%s'", + OPENSEARCH_DISTRIBUTION, OPENDISTRO_DISTRIUBTION, ELASTICSEARCH_DISTRIBUTION, distribution)); } } - - private boolean versionSupportsPointInTimeForOpenSearch(final String version) { - final DefaultArtifactVersion cutoffVersion = new DefaultArtifactVersion(OPENSEARCH_POINT_IN_TIME_SUPPORT_VERSION_CUTOFF); - final DefaultArtifactVersion actualVersion = new DefaultArtifactVersion(version); - return actualVersion.compareTo(cutoffVersion) >= 0; - } } diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java index 68affcfc64..e5f19cffeb 100644 --- a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/OpenSearchSourceTest.java @@ -15,6 +15,7 @@ import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.coordinator.SourceCoordinator; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.OpenSearchClientFactory; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessor; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy; @@ -35,6 +36,9 @@ public class OpenSearchSourceTest { @Mock private OpenSearchService openSearchService; + @Mock + private OpenSearchClientFactory openSearchClientFactory; + @Mock private SearchAccessorStrategy searchAccessorStrategy; @@ -66,8 +70,10 @@ void start_with_non_null_buffer_does_not_throw() { objectUnderTest.setSourceCoordinator(sourceCoordinator); try (final MockedStatic searchAccessorStrategyMockedStatic = mockStatic(SearchAccessorStrategy.class); + final MockedStatic openSearchClientFactoryMockedStatic = mockStatic(OpenSearchClientFactory.class); final MockedStatic openSearchServiceMockedStatic = mockStatic(OpenSearchService.class)) { - searchAccessorStrategyMockedStatic.when(() -> SearchAccessorStrategy.create(openSearchSourceConfiguration, awsCredentialsSupplier)).thenReturn(searchAccessorStrategy); + openSearchClientFactoryMockedStatic.when(() -> OpenSearchClientFactory.create(awsCredentialsSupplier)).thenReturn(openSearchClientFactory); + searchAccessorStrategyMockedStatic.when(() -> SearchAccessorStrategy.create(openSearchSourceConfiguration, openSearchClientFactory)).thenReturn(searchAccessorStrategy); openSearchServiceMockedStatic.when(() -> OpenSearchService.createOpenSearchService(searchAccessor, sourceCoordinator, openSearchSourceConfiguration, buffer)) .thenReturn(openSearchService); diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessorTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessorTest.java new file mode 100644 index 0000000000..20cbd617a1 --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/ElasticsearchAccessorTest.java @@ -0,0 +1,281 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.ElasticsearchException; +import co.elastic.clients.elasticsearch._types.ErrorCause; +import co.elastic.clients.elasticsearch.core.ClosePointInTimeRequest; +import co.elastic.clients.elasticsearch.core.ClosePointInTimeResponse; +import co.elastic.clients.elasticsearch.core.OpenPointInTimeRequest; +import co.elastic.clients.elasticsearch.core.OpenPointInTimeResponse; +import co.elastic.clients.elasticsearch.core.SearchRequest; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.search.Hit; +import co.elastic.clients.elasticsearch.core.search.HitsMetadata; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.exceptions.SearchContextLimitException; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreatePointInTimeRequest; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.CreatePointInTimeResponse; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.DeletePointInTimeRequest; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.NoSearchContextSearchRequest; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchPointInTimeRequest; +import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchWithSearchAfterResults; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.opensearch.worker.client.OpenSearchAccessor.PIT_RESOURCE_LIMIT_ERROR_TYPE; + +@ExtendWith(MockitoExtension.class) +public class ElasticsearchAccessorTest { + + @Mock + private ElasticsearchClient elasticSearchClient; + + private SearchAccessor createObjectUnderTest() { + return new ElasticsearchAccessor(elasticSearchClient, SearchContextType.POINT_IN_TIME); + } + + @Test + void create_pit_returns_expected_create_point_in_time_response() throws IOException { + final String indexName = UUID.randomUUID().toString(); + final String keepAlive = UUID.randomUUID().toString(); + + final CreatePointInTimeRequest createPointInTimeRequest = mock(CreatePointInTimeRequest.class); + when(createPointInTimeRequest.getIndex()).thenReturn(indexName); + when(createPointInTimeRequest.getKeepAlive()).thenReturn(keepAlive); + + final String pitId = UUID.randomUUID().toString(); + final OpenPointInTimeResponse createPitResponse = mock(OpenPointInTimeResponse.class); + when(createPitResponse.id()).thenReturn(pitId); + + when(elasticSearchClient.openPointInTime(any(OpenPointInTimeRequest.class))).thenReturn(createPitResponse); + + final CreatePointInTimeResponse createPointInTimeResponse = createObjectUnderTest().createPit(createPointInTimeRequest); + assertThat(createPointInTimeResponse, notNullValue()); + assertThat(createPointInTimeResponse.getPitCreationTime(), lessThanOrEqualTo(Instant.now().toEpochMilli())); + assertThat(createPointInTimeResponse.getPitId(), equalTo(pitId)); + } + + @Test + void create_pit_with_exception_for_pit_limit_throws_SearchContextLimitException() throws IOException { + final String indexName = UUID.randomUUID().toString(); + final String keepAlive = UUID.randomUUID().toString(); + + final CreatePointInTimeRequest createPointInTimeRequest = mock(CreatePointInTimeRequest.class); + when(createPointInTimeRequest.getIndex()).thenReturn(indexName); + when(createPointInTimeRequest.getKeepAlive()).thenReturn(keepAlive); + + final ElasticsearchException elasticsearchException = mock(ElasticsearchException.class); + final ErrorCause errorCause = mock(ErrorCause.class); + final ErrorCause rootCause = mock(ErrorCause.class); + when(rootCause.type()).thenReturn(PIT_RESOURCE_LIMIT_ERROR_TYPE); + when(rootCause.reason()).thenReturn(UUID.randomUUID().toString()); + when(errorCause.causedBy()).thenReturn(rootCause); + when(elasticsearchException.error()).thenReturn(errorCause); + + when(elasticSearchClient.openPointInTime(any(OpenPointInTimeRequest.class))).thenThrow(elasticsearchException); + + assertThrows(SearchContextLimitException.class, () -> createObjectUnderTest().createPit(createPointInTimeRequest)); + } + + @Test + void createPit_throws_Elasticsearch_exception_throws_that_exception() throws IOException { + final String indexName = UUID.randomUUID().toString(); + final String keepAlive = UUID.randomUUID().toString(); + + final CreatePointInTimeRequest createPointInTimeRequest = mock(CreatePointInTimeRequest.class); + when(createPointInTimeRequest.getIndex()).thenReturn(indexName); + when(createPointInTimeRequest.getKeepAlive()).thenReturn(keepAlive); + + final ElasticsearchException openSearchException = mock(ElasticsearchException.class); + final ErrorCause errorCause = mock(ErrorCause.class); + when(errorCause.causedBy()).thenReturn(null); + when(openSearchException.error()).thenReturn(errorCause); + + when(elasticSearchClient.openPointInTime(any(OpenPointInTimeRequest.class))).thenThrow(openSearchException); + + assertThrows(ElasticsearchException.class, () -> createObjectUnderTest().createPit(createPointInTimeRequest)); + } + + @Test + void createPit_throws_runtime_exception_throws_IO_Exception() throws IOException { + final String indexName = UUID.randomUUID().toString(); + final String keepAlive = UUID.randomUUID().toString(); + + final CreatePointInTimeRequest createPointInTimeRequest = mock(CreatePointInTimeRequest.class); + when(createPointInTimeRequest.getIndex()).thenReturn(indexName); + when(createPointInTimeRequest.getKeepAlive()).thenReturn(keepAlive); + + when(elasticSearchClient.openPointInTime(any(OpenPointInTimeRequest.class))).thenThrow(IOException.class); + + assertThrows(RuntimeException.class, () -> createObjectUnderTest().createPit(createPointInTimeRequest)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void delete_pit_with_no_exception_does_not_throw(final boolean successful) throws IOException { + final String pitId = UUID.randomUUID().toString(); + + final DeletePointInTimeRequest deletePointInTimeRequest = mock(DeletePointInTimeRequest.class); + when(deletePointInTimeRequest.getPitId()).thenReturn(pitId); + + final ClosePointInTimeResponse deletePitResponse = mock(ClosePointInTimeResponse.class); + when(deletePitResponse.succeeded()).thenReturn(successful); + + when(elasticSearchClient.closePointInTime(any(ClosePointInTimeRequest.class))).thenReturn(deletePitResponse); + + createObjectUnderTest().deletePit(deletePointInTimeRequest); + } + + @Test + void delete_pit_does_not_throw_during_opensearch_exception() throws IOException { + final String pitId = UUID.randomUUID().toString(); + + final DeletePointInTimeRequest deletePointInTimeRequest = mock(DeletePointInTimeRequest.class); + when(deletePointInTimeRequest.getPitId()).thenReturn(pitId); + + when(elasticSearchClient.closePointInTime(any(ClosePointInTimeRequest.class))).thenThrow(ElasticsearchException.class); + + createObjectUnderTest().deletePit(deletePointInTimeRequest); + } + + @Test + void delete_pit_does_not_throw_exception_when_client_throws_IOException() throws IOException { + final String pitId = UUID.randomUUID().toString(); + + final DeletePointInTimeRequest deletePointInTimeRequest = mock(DeletePointInTimeRequest.class); + when(deletePointInTimeRequest.getPitId()).thenReturn(pitId); + + when(elasticSearchClient.closePointInTime(any(ClosePointInTimeRequest.class))).thenThrow(IOException.class); + + createObjectUnderTest().deletePit(deletePointInTimeRequest); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void search_with_pit_returns_expected_SearchPointInTimeResponse(final boolean hasSearchAfter) throws IOException { + final String pitId = UUID.randomUUID().toString(); + final Integer paginationSize = new Random().nextInt(); + final List searchAfter = Collections.singletonList(UUID.randomUUID().toString()); + + final SearchPointInTimeRequest searchPointInTimeRequest = mock(SearchPointInTimeRequest.class); + when(searchPointInTimeRequest.getPitId()).thenReturn(pitId); + when(searchPointInTimeRequest.getKeepAlive()).thenReturn("1m"); + when(searchPointInTimeRequest.getPaginationSize()).thenReturn(paginationSize); + + if (hasSearchAfter) { + when(searchPointInTimeRequest.getSearchAfter()).thenReturn(searchAfter); + } else { + when(searchPointInTimeRequest.getSearchAfter()).thenReturn(null); + } + + final SearchResponse searchResponse = mock(SearchResponse.class); + final HitsMetadata hitsMetadata = mock(HitsMetadata.class); + final List> hits = new ArrayList<>(); + final Hit firstHit = mock(Hit.class); + when(firstHit.id()).thenReturn(UUID.randomUUID().toString()); + when(firstHit.index()).thenReturn(UUID.randomUUID().toString()); + when(firstHit.source()).thenReturn(mock(ObjectNode.class)); + + final Hit secondHit = mock(Hit.class); + when(secondHit.id()).thenReturn(UUID.randomUUID().toString()); + when(secondHit.index()).thenReturn(UUID.randomUUID().toString()); + when(secondHit.source()).thenReturn(mock(ObjectNode.class)); + when(secondHit.sort()).thenReturn(searchAfter); + + hits.add(firstHit); + hits.add(secondHit); + + when(hitsMetadata.hits()).thenReturn(hits); + when(searchResponse.hits()).thenReturn(hitsMetadata); + + final ArgumentCaptor searchRequestArgumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + + when(elasticSearchClient.search(searchRequestArgumentCaptor.capture(), eq(ObjectNode.class))).thenReturn(searchResponse); + + final SearchWithSearchAfterResults searchWithSearchAfterResults = createObjectUnderTest().searchWithPit(searchPointInTimeRequest); + + assertThat(searchWithSearchAfterResults, notNullValue()); + assertThat(searchWithSearchAfterResults.getDocuments(), notNullValue()); + assertThat(searchWithSearchAfterResults.getDocuments().size(), equalTo(2)); + + assertThat(searchWithSearchAfterResults.getNextSearchAfter(), equalTo(secondHit.sort())); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void search_without_search_context_returns_expected_SearchPointInTimeResponse(final boolean hasSearchAfter) throws IOException { + final Integer paginationSize = new Random().nextInt(); + final String index = UUID.randomUUID().toString(); + final List searchAfter = Collections.singletonList(UUID.randomUUID().toString()); + + final NoSearchContextSearchRequest noSearchContextSearchRequest = mock(NoSearchContextSearchRequest.class); + when(noSearchContextSearchRequest.getPaginationSize()).thenReturn(paginationSize); + when(noSearchContextSearchRequest.getIndex()).thenReturn(index); + + if (hasSearchAfter) { + when(noSearchContextSearchRequest.getSearchAfter()).thenReturn(searchAfter); + } else { + when(noSearchContextSearchRequest.getSearchAfter()).thenReturn(null); + } + + final SearchResponse searchResponse = mock(SearchResponse.class); + final HitsMetadata hitsMetadata = mock(HitsMetadata.class); + final List> hits = new ArrayList<>(); + final Hit firstHit = mock(Hit.class); + when(firstHit.id()).thenReturn(UUID.randomUUID().toString()); + when(firstHit.index()).thenReturn(index); + when(firstHit.source()).thenReturn(mock(ObjectNode.class)); + + final Hit secondHit = mock(Hit.class); + when(secondHit.id()).thenReturn(UUID.randomUUID().toString()); + when(secondHit.index()).thenReturn(index); + when(secondHit.source()).thenReturn(mock(ObjectNode.class)); + when(secondHit.sort()).thenReturn(searchAfter); + + hits.add(firstHit); + hits.add(secondHit); + + when(hitsMetadata.hits()).thenReturn(hits); + when(searchResponse.hits()).thenReturn(hitsMetadata); + + final ArgumentCaptor searchRequestArgumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + + when(elasticSearchClient.search(searchRequestArgumentCaptor.capture(), eq(ObjectNode.class))).thenReturn(searchResponse); + + final SearchWithSearchAfterResults searchWithSearchAfterResults = createObjectUnderTest().searchWithoutSearchContext(noSearchContextSearchRequest); + + assertThat(searchWithSearchAfterResults, notNullValue()); + assertThat(searchWithSearchAfterResults.getDocuments(), notNullValue()); + assertThat(searchWithSearchAfterResults.getDocuments().size(), equalTo(2)); + + assertThat(searchWithSearchAfterResults.getNextSearchAfter(), equalTo(secondHit.sort())); + } +} diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactoryTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactoryTest.java new file mode 100644 index 0000000000..cc811625d1 --- /dev/null +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchClientFactoryTest.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; +import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class OpenSearchClientFactoryTest { + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private OpenSearchSourceConfiguration openSearchSourceConfiguration; + + @Mock + private ConnectionConfiguration connectionConfiguration; + + @BeforeEach + void setup() { + when(openSearchSourceConfiguration.getHosts()).thenReturn(List.of("http://localhost:9200")); + when(openSearchSourceConfiguration.getConnectionConfiguration()).thenReturn(connectionConfiguration); + } + + private OpenSearchClientFactory createObjectUnderTest() { + return OpenSearchClientFactory.create(awsCredentialsSupplier); + } + + @Test + void provideOpenSearchClient_with_username_and_password() { + final String username = UUID.randomUUID().toString(); + final String password = UUID.randomUUID().toString(); + when(openSearchSourceConfiguration.getUsername()).thenReturn(username); + when(openSearchSourceConfiguration.getPassword()).thenReturn(password); + + when(connectionConfiguration.getCertPath()).thenReturn(null); + when(connectionConfiguration.getSocketTimeout()).thenReturn(null); + when(connectionConfiguration.getConnectTimeout()).thenReturn(null); + when(connectionConfiguration.isInsecure()).thenReturn(true); + + when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(null); + + final OpenSearchClient openSearchClient = createObjectUnderTest().provideOpenSearchClient(openSearchSourceConfiguration); + assertThat(openSearchClient, notNullValue()); + + verifyNoInteractions(awsCredentialsSupplier); + + } + + @Test + void provideElasticSearchClient_with_username_and_password() { + final String username = UUID.randomUUID().toString(); + final String password = UUID.randomUUID().toString(); + when(openSearchSourceConfiguration.getUsername()).thenReturn(username); + when(openSearchSourceConfiguration.getPassword()).thenReturn(password); + + when(connectionConfiguration.getCertPath()).thenReturn(null); + when(connectionConfiguration.getSocketTimeout()).thenReturn(null); + when(connectionConfiguration.getConnectTimeout()).thenReturn(null); + when(connectionConfiguration.isInsecure()).thenReturn(true); + + final ElasticsearchClient elasticsearchClient = createObjectUnderTest().provideElasticSearchClient(openSearchSourceConfiguration); + assertThat(elasticsearchClient, notNullValue()); + + verifyNoInteractions(awsCredentialsSupplier); + } + + @Test + void provideOpenSearchClient_with_aws_auth() { + when(connectionConfiguration.getCertPath()).thenReturn(null); + when(connectionConfiguration.getSocketTimeout()).thenReturn(null); + when(connectionConfiguration.getConnectTimeout()).thenReturn(null); + + final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class); + when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1); + final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role"; + when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn); + when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap()); + when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration); + + final ArgumentCaptor awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); + final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); + when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider); + + final OpenSearchClient openSearchClient = createObjectUnderTest().provideOpenSearchClient(openSearchSourceConfiguration); + assertThat(openSearchClient, notNullValue()); + + final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue(); + assertThat(awsCredentialsOptions, notNullValue()); + assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1)); + assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap())); + assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + } +} diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchIndexPartitionCreationSupplierTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchIndexPartitionCreationSupplierTest.java index 0a424b0610..3c0ed9628a 100644 --- a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchIndexPartitionCreationSupplierTest.java +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/OpenSearchIndexPartitionCreationSupplierTest.java @@ -5,6 +5,9 @@ package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.ElasticsearchException; +import co.elastic.clients.elasticsearch.cat.ElasticsearchCatClient; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; @@ -49,6 +52,9 @@ public class OpenSearchIndexPartitionCreationSupplierTest { @Mock private OpenSearchClient openSearchClient; + @Mock + private ElasticsearchClient elasticsearchClient; + private OpenSearchIndexPartitionCreationSupplier createObjectUnderTest() { return new OpenSearchIndexPartitionCreationSupplier(openSearchSourceConfiguration, clusterClientFactory); } @@ -75,6 +81,21 @@ void apply_with_opensearch_client_cat_indices_throws_exception_returns_empty_lis assertThat(partitionIdentifierList.isEmpty(), equalTo(true)); } + @ParameterizedTest + @MethodSource("elasticsearchCatIndicesExceptions") + void apply_with_elasticsearch_client_cat_indices_throws_exception_returns_empty_list(final Class exception) throws IOException { + when(clusterClientFactory.getClient()).thenReturn(elasticsearchClient); + + final ElasticsearchCatClient elasticsearchCatClient = mock(ElasticsearchCatClient.class); + when(elasticsearchCatClient.indices()).thenThrow(exception); + when(elasticsearchClient.cat()).thenReturn(elasticsearchCatClient); + + final List partitionIdentifierList = createObjectUnderTest().apply(Collections.emptyMap()); + + assertThat(partitionIdentifierList, notNullValue()); + assertThat(partitionIdentifierList.isEmpty(), equalTo(true)); + } + @Test void apply_with_opensearch_client_with_no_indices_return_empty_list() throws IOException { when(clusterClientFactory.getClient()).thenReturn(openSearchClient); @@ -91,6 +112,22 @@ void apply_with_opensearch_client_with_no_indices_return_empty_list() throws IOE assertThat(partitionIdentifierList.isEmpty(), equalTo(true)); } + @Test + void apply_with_elasticsearch_client_with_no_indices_return_empty_list() throws IOException { + when(clusterClientFactory.getClient()).thenReturn(elasticsearchClient); + + final ElasticsearchCatClient elasticsearchCatClient = mock(ElasticsearchCatClient.class); + final co.elastic.clients.elasticsearch.cat.IndicesResponse indicesResponse = mock(co.elastic.clients.elasticsearch.cat.IndicesResponse.class); + when(indicesResponse.valueBody()).thenReturn(Collections.emptyList()); + when(elasticsearchCatClient.indices()).thenReturn(indicesResponse); + when(elasticsearchClient.cat()).thenReturn(elasticsearchCatClient); + + final List partitionIdentifierList = createObjectUnderTest().apply(Collections.emptyMap()); + + assertThat(partitionIdentifierList, notNullValue()); + assertThat(partitionIdentifierList.isEmpty(), equalTo(true)); + } + @Test void apply_with_opensearch_client_with_indices_filters_them_correctly() throws IOException { when(clusterClientFactory.getClient()).thenReturn(openSearchClient); @@ -146,8 +183,68 @@ void apply_with_opensearch_client_with_indices_filters_them_correctly() throws I assertThat(partitionIdentifierList, notNullValue()); } + @Test + void apply_with_elasticsearch_client_with_indices_filters_them_correctly() throws IOException { + when(clusterClientFactory.getClient()).thenReturn(elasticsearchClient); + + final ElasticsearchCatClient elasticsearchCatClient = mock(ElasticsearchCatClient.class); + final co.elastic.clients.elasticsearch.cat.IndicesResponse indicesResponse = mock(co.elastic.clients.elasticsearch.cat.IndicesResponse.class); + + final IndexParametersConfiguration indexParametersConfiguration = mock(IndexParametersConfiguration.class); + + final List includedIndices = new ArrayList<>(); + final OpenSearchIndex includeIndex = mock(OpenSearchIndex.class); + final String includePattern = "my-pattern-[a-c].*"; + when(includeIndex.getIndexNamePattern()).thenReturn(Pattern.compile(includePattern)); + includedIndices.add(includeIndex); + + final List excludedIndices = new ArrayList<>(); + final OpenSearchIndex excludeIndex = mock(OpenSearchIndex.class); + final String excludePattern = "my-pattern-[a-c]-exclude"; + when(excludeIndex.getIndexNamePattern()).thenReturn(Pattern.compile(excludePattern)); + excludedIndices.add(excludeIndex); + + final OpenSearchIndex secondExcludeIndex = mock(OpenSearchIndex.class); + final String secondExcludePattern = "second-exclude-.*"; + when(secondExcludeIndex.getIndexNamePattern()).thenReturn(Pattern.compile(secondExcludePattern)); + excludedIndices.add(secondExcludeIndex); + + when(indexParametersConfiguration.getIncludedIndices()).thenReturn(includedIndices); + when(indexParametersConfiguration.getExcludedIndices()).thenReturn(excludedIndices); + when(openSearchSourceConfiguration.getIndexParametersConfiguration()).thenReturn(indexParametersConfiguration); + + final List indicesRecords = new ArrayList<>(); + final co.elastic.clients.elasticsearch.cat.indices.IndicesRecord includedIndex = mock(co.elastic.clients.elasticsearch.cat.indices.IndicesRecord.class); + when(includedIndex.index()).thenReturn("my-pattern-a-include"); + final co.elastic.clients.elasticsearch.cat.indices.IndicesRecord excludedIndex = mock(co.elastic.clients.elasticsearch.cat.indices.IndicesRecord.class); + when(excludedIndex.index()).thenReturn("second-exclude-test"); + final co.elastic.clients.elasticsearch.cat.indices.IndicesRecord includedAndThenExcluded = mock(co.elastic.clients.elasticsearch.cat.indices.IndicesRecord.class); + when(includedAndThenExcluded.index()).thenReturn("my-pattern-a-exclude"); + final co.elastic.clients.elasticsearch.cat.indices.IndicesRecord neitherIncludedOrExcluded = mock(co.elastic.clients.elasticsearch.cat.indices.IndicesRecord.class); + when(neitherIncludedOrExcluded.index()).thenReturn("random-index"); + + indicesRecords.add(includedIndex); + indicesRecords.add(excludedIndex); + indicesRecords.add(includedAndThenExcluded); + indicesRecords.add(neitherIncludedOrExcluded); + + when(indicesResponse.valueBody()).thenReturn(indicesRecords); + + when(elasticsearchCatClient.indices()).thenReturn(indicesResponse); + when(elasticsearchClient.cat()).thenReturn(elasticsearchCatClient); + + final List partitionIdentifierList = createObjectUnderTest().apply(Collections.emptyMap()); + + assertThat(partitionIdentifierList, notNullValue()); + } + private static Stream opensearchCatIndicesExceptions() { return Stream.of(Arguments.of(IOException.class), Arguments.of(OpenSearchException.class)); } + + private static Stream elasticsearchCatIndicesExceptions() { + return Stream.of(Arguments.of(IOException.class), + Arguments.of(ElasticsearchException.class)); + } } diff --git a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessStrategyTest.java b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessStrategyTest.java index 582c9f4652..e30ec3559e 100644 --- a/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessStrategyTest.java +++ b/data-prepper-plugins/opensearch-source/src/test/java/org/opensearch/dataprepper/plugins/source/opensearch/worker/client/SearchAccessStrategyTest.java @@ -5,38 +5,29 @@ package org.opensearch.dataprepper.plugins.source.opensearch.worker.client; -import org.junit.jupiter.api.BeforeEach; +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.ElasticsearchVersionInfo; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.mockito.MockedConstruction; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.opensearch.OpenSearchClient; import org.opensearch.client.opensearch._types.OpenSearchVersionInfo; import org.opensearch.client.opensearch.core.InfoResponse; -import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; -import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.client.util.MissingRequiredPropertyException; import org.opensearch.dataprepper.plugins.source.opensearch.OpenSearchSourceConfiguration; -import org.opensearch.dataprepper.plugins.source.opensearch.configuration.AwsAuthenticationConfiguration; -import org.opensearch.dataprepper.plugins.source.opensearch.configuration.ConnectionConfiguration; import org.opensearch.dataprepper.plugins.source.opensearch.configuration.SearchConfiguration; import org.opensearch.dataprepper.plugins.source.opensearch.worker.client.model.SearchContextType; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import java.util.Collections; -import java.util.List; -import java.util.UUID; +import java.io.IOException; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockConstruction; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.opensearch.dataprepper.plugins.source.opensearch.worker.client.SearchAccessorStrategy.OPENSEARCH_DISTRIBUTION; @@ -44,79 +35,97 @@ public class SearchAccessStrategyTest { @Mock - private AwsCredentialsSupplier awsCredentialsSupplier; + private OpenSearchClientFactory openSearchClientFactory; @Mock private OpenSearchSourceConfiguration openSearchSourceConfiguration; - @Mock - private ConnectionConfiguration connectionConfiguration; - - @BeforeEach - void setup() { - when(openSearchSourceConfiguration.getHosts()).thenReturn(List.of("http://localhost:9200")); - when(openSearchSourceConfiguration.getConnectionConfiguration()).thenReturn(connectionConfiguration); - } - private SearchAccessorStrategy createObjectUnderTest() { - return SearchAccessorStrategy.create(openSearchSourceConfiguration, awsCredentialsSupplier); + return SearchAccessorStrategy.create(openSearchSourceConfiguration, openSearchClientFactory); } @ParameterizedTest @ValueSource(strings = {"2.5.0", "2.6.1", "3.0.0"}) - void testHappyPath_with_username_and_password_and_insecure_for_different_point_in_time_versions_for_opensearch(final String osVersion) { - final String username = UUID.randomUUID().toString(); - final String password = UUID.randomUUID().toString(); - when(openSearchSourceConfiguration.getUsername()).thenReturn(username); - when(openSearchSourceConfiguration.getPassword()).thenReturn(password); - - when(connectionConfiguration.getCertPath()).thenReturn(null); - when(connectionConfiguration.getSocketTimeout()).thenReturn(null); - when(connectionConfiguration.getConnectTimeout()).thenReturn(null); - when(connectionConfiguration.isInsecure()).thenReturn(true); + void testHappyPath_for_different_point_in_time_versions_for_opensearch(final String osVersion) throws IOException { final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class); when(searchConfiguration.getSearchContextType()).thenReturn(null); when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration); - when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(null); - final InfoResponse infoResponse = mock(InfoResponse.class); final OpenSearchVersionInfo openSearchVersionInfo = mock(OpenSearchVersionInfo.class); when(openSearchVersionInfo.distribution()).thenReturn(OPENSEARCH_DISTRIBUTION); when(openSearchVersionInfo.number()).thenReturn(osVersion); when(infoResponse.version()).thenReturn(openSearchVersionInfo); - try (MockedConstruction openSearchClientMockedConstruction = mockConstruction(OpenSearchClient.class, - (clientMock, context) -> { - when(clientMock.info()).thenReturn(infoResponse); - })) { + final OpenSearchClient openSearchClient = mock(OpenSearchClient.class); + when(openSearchClient.info()).thenReturn(infoResponse); + when(openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration)).thenReturn(openSearchClient); + + final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); + assertThat(searchAccessor, notNullValue()); + assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.POINT_IN_TIME)); + } + + @ParameterizedTest + @ValueSource(strings = {"7.10.2", "8.1.1", "7.10.0"}) + void testHappyPath_for_different_point_in_time_versions_for_elasticsearch(final String esVersion) throws IOException { + + final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class); + when(searchConfiguration.getSearchContextType()).thenReturn(null); + when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration); + + final OpenSearchClient openSearchClient = mock(OpenSearchClient.class); + when(openSearchClient.info()).thenThrow(MissingRequiredPropertyException.class); + when(openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration)).thenReturn(openSearchClient); + + final ElasticsearchClient elasticsearchClient = mock(ElasticsearchClient.class); + + final co.elastic.clients.elasticsearch.core.InfoResponse infoResponse = mock(co.elastic.clients.elasticsearch.core.InfoResponse.class); + final ElasticsearchVersionInfo elasticsearchVersionInfo = mock(ElasticsearchVersionInfo.class); + when(elasticsearchVersionInfo.buildFlavor()).thenReturn("default"); + when(elasticsearchVersionInfo.number()).thenReturn(esVersion); + when(infoResponse.version()).thenReturn(elasticsearchVersionInfo); + + when(elasticsearchClient.info()).thenReturn(infoResponse); + when(openSearchClientFactory.provideElasticSearchClient(openSearchSourceConfiguration)).thenReturn(elasticsearchClient); + + final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); + assertThat(searchAccessor, notNullValue()); + assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.POINT_IN_TIME)); + + } + + @ParameterizedTest + @CsvSource(value = {"6.3.0,default", "7.9.0,default", "0.3.2,default", "7.10.2,oss"}) + void search_context_type_set_to_point_in_time_with_invalid_version_throws_IllegalArgumentException_for_elasticsearch(final String esVersion, final String esBuildFlavor) throws IOException { + + final OpenSearchClient openSearchClient = mock(OpenSearchClient.class); + when(openSearchClient.info()).thenThrow(MissingRequiredPropertyException.class); + when(openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration)).thenReturn(openSearchClient); - final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); - assertThat(searchAccessor, notNullValue()); - assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.POINT_IN_TIME)); + final ElasticsearchClient elasticsearchClient = mock(ElasticsearchClient.class); - final List constructedClients = openSearchClientMockedConstruction.constructed(); - assertThat(constructedClients.size(), equalTo(1)); - } + final co.elastic.clients.elasticsearch.core.InfoResponse infoResponse = mock(co.elastic.clients.elasticsearch.core.InfoResponse.class); + final ElasticsearchVersionInfo elasticsearchVersionInfo = mock(ElasticsearchVersionInfo.class); + when(elasticsearchVersionInfo.buildFlavor()).thenReturn(esBuildFlavor); + when(elasticsearchVersionInfo.number()).thenReturn(esVersion); + when(infoResponse.version()).thenReturn(elasticsearchVersionInfo); - verifyNoInteractions(awsCredentialsSupplier); + final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class); + when(searchConfiguration.getSearchContextType()).thenReturn(SearchContextType.POINT_IN_TIME); + when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration); + when(elasticsearchClient.info()).thenReturn(infoResponse); + when(openSearchClientFactory.provideElasticSearchClient(openSearchSourceConfiguration)).thenReturn(elasticsearchClient); + + + assertThrows(IllegalArgumentException.class, () -> createObjectUnderTest().getSearchAccessor()); } @ParameterizedTest @ValueSource(strings = {"1.3.0", "2.4.9", "0.3.2"}) - void testHappyPath_with_aws_credentials_for_different_scroll_versions_for_opensearch(final String osVersion) { - when(connectionConfiguration.getCertPath()).thenReturn(null); - when(connectionConfiguration.getSocketTimeout()).thenReturn(null); - when(connectionConfiguration.getConnectTimeout()).thenReturn(null); - - final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class); - when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1); - final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role"; - when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn); - when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap()); - when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration); + void testHappyPath_with_for_different_scroll_versions_for_opensearch(final String osVersion) throws IOException { final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class); when(searchConfiguration.getSearchContextType()).thenReturn(null); @@ -128,43 +137,18 @@ void testHappyPath_with_aws_credentials_for_different_scroll_versions_for_opense when(openSearchVersionInfo.number()).thenReturn(osVersion); when(infoResponse.version()).thenReturn(openSearchVersionInfo); - final ArgumentCaptor awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); - final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); - when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider); - - try (MockedConstruction openSearchClientMockedConstruction = mockConstruction(OpenSearchClient.class, - (clientMock, context) -> { - when(clientMock.info()).thenReturn(infoResponse); - })) { - - final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); - assertThat(searchAccessor, notNullValue()); - assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.SCROLL)); - - final List constructedClients = openSearchClientMockedConstruction.constructed(); - assertThat(constructedClients.size(), equalTo(1)); - } + final OpenSearchClient openSearchClient = mock(OpenSearchClient.class); + when(openSearchClient.info()).thenReturn(infoResponse); + when(openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration)).thenReturn(openSearchClient); - final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue(); - assertThat(awsCredentialsOptions, notNullValue()); - assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1)); - assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap())); - assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); + assertThat(searchAccessor, notNullValue()); + assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.SCROLL)); } @ParameterizedTest @ValueSource(strings = {"1.3.0", "2.4.9", "0.3.2"}) - void search_context_type_set_to_point_in_time_with_invalid_version_throws_IllegalArgumentException(final String osVersion) { - when(connectionConfiguration.getCertPath()).thenReturn(null); - when(connectionConfiguration.getSocketTimeout()).thenReturn(null); - when(connectionConfiguration.getConnectTimeout()).thenReturn(null); - - final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class); - when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1); - final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role"; - when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn); - when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap()); - when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration); + void search_context_type_set_to_point_in_time_with_invalid_version_throws_IllegalArgumentException_for_opensearch(final String osVersion) throws IOException { final InfoResponse infoResponse = mock(InfoResponse.class); final OpenSearchVersionInfo openSearchVersionInfo = mock(OpenSearchVersionInfo.class); @@ -172,45 +156,20 @@ void search_context_type_set_to_point_in_time_with_invalid_version_throws_Illega when(openSearchVersionInfo.number()).thenReturn(osVersion); when(infoResponse.version()).thenReturn(openSearchVersionInfo); + final OpenSearchClient openSearchClient = mock(OpenSearchClient.class); + when(openSearchClient.info()).thenReturn(infoResponse); + when(openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration)).thenReturn(openSearchClient); + final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class); when(searchConfiguration.getSearchContextType()).thenReturn(SearchContextType.POINT_IN_TIME); when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration); - final ArgumentCaptor awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); - final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); - when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider); - - try (MockedConstruction openSearchClientMockedConstruction = mockConstruction(OpenSearchClient.class, - (clientMock, context) -> { - when(clientMock.info()).thenReturn(infoResponse); - })) { - - assertThrows(IllegalArgumentException.class, () -> createObjectUnderTest().getSearchAccessor()); - - final List constructedClients = openSearchClientMockedConstruction.constructed(); - assertThat(constructedClients.size(), equalTo(1)); - } - - final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue(); - assertThat(awsCredentialsOptions, notNullValue()); - assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1)); - assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap())); - assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + assertThrows(IllegalArgumentException.class, () -> createObjectUnderTest().getSearchAccessor()); } @ParameterizedTest @ValueSource(strings = {"1.3.0", "2.4.9", "2.5.0"}) - void search_context_type_set_to_none_uses_that_search_context_regardless_of_version(final String osVersion) { - when(connectionConfiguration.getCertPath()).thenReturn(null); - when(connectionConfiguration.getSocketTimeout()).thenReturn(null); - when(connectionConfiguration.getConnectTimeout()).thenReturn(null); - - final AwsAuthenticationConfiguration awsAuthenticationConfiguration = mock(AwsAuthenticationConfiguration.class); - when(awsAuthenticationConfiguration.getAwsRegion()).thenReturn(Region.US_EAST_1); - final String stsRoleArn = "arn:aws:iam::123456789012:role/my-role"; - when(awsAuthenticationConfiguration.getAwsStsRoleArn()).thenReturn(stsRoleArn); - when(awsAuthenticationConfiguration.getAwsStsHeaderOverrides()).thenReturn(Collections.emptyMap()); - when(openSearchSourceConfiguration.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationConfiguration); + void search_context_type_set_to_none_uses_that_search_context_regardless_of_version(final String osVersion) throws IOException { final InfoResponse infoResponse = mock(InfoResponse.class); final OpenSearchVersionInfo openSearchVersionInfo = mock(OpenSearchVersionInfo.class); @@ -218,31 +177,16 @@ void search_context_type_set_to_none_uses_that_search_context_regardless_of_vers when(openSearchVersionInfo.number()).thenReturn(osVersion); when(infoResponse.version()).thenReturn(openSearchVersionInfo); + final OpenSearchClient openSearchClient = mock(OpenSearchClient.class); + when(openSearchClient.info()).thenReturn(infoResponse); + when(openSearchClientFactory.provideOpenSearchClient(openSearchSourceConfiguration)).thenReturn(openSearchClient); + final SearchConfiguration searchConfiguration = mock(SearchConfiguration.class); when(searchConfiguration.getSearchContextType()).thenReturn(SearchContextType.NONE); when(openSearchSourceConfiguration.getSearchConfiguration()).thenReturn(searchConfiguration); - final ArgumentCaptor awsCredentialsOptionsArgumentCaptor = ArgumentCaptor.forClass(AwsCredentialsOptions.class); - final AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class); - when(awsCredentialsSupplier.getProvider(awsCredentialsOptionsArgumentCaptor.capture())).thenReturn(awsCredentialsProvider); - - try (MockedConstruction openSearchClientMockedConstruction = mockConstruction(OpenSearchClient.class, - (clientMock, context) -> { - when(clientMock.info()).thenReturn(infoResponse); - })) { - - final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); - assertThat(searchAccessor, notNullValue()); - assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.NONE)); - - final List constructedClients = openSearchClientMockedConstruction.constructed(); - assertThat(constructedClients.size(), equalTo(1)); - } - - final AwsCredentialsOptions awsCredentialsOptions = awsCredentialsOptionsArgumentCaptor.getValue(); - assertThat(awsCredentialsOptions, notNullValue()); - assertThat(awsCredentialsOptions.getRegion(), equalTo(Region.US_EAST_1)); - assertThat(awsCredentialsOptions.getStsHeaderOverrides(), equalTo(Collections.emptyMap())); - assertThat(awsCredentialsOptions.getStsRoleArn(), equalTo(stsRoleArn)); + final SearchAccessor searchAccessor = createObjectUnderTest().getSearchAccessor(); + assertThat(searchAccessor, notNullValue()); + assertThat(searchAccessor.getSearchContextType(), equalTo(SearchContextType.NONE)); } }