Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace usages of ThreadContext.stashContext with pluginSubject.runAs #715

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ See the [CONTRIBUTING guide](./CONTRIBUTING.md#Changelog) for instructions on ho
### Documentation
### Maintenance
### Refactoring
- Replace usages of ThreadContext.stashContext with pluginSubject.runAs ([#715](https://github.com/opensearch-project/geospatial/pull/715))
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import org.opensearch.geospatial.ip2geo.common.Ip2GeoSettings;
import org.opensearch.geospatial.ip2geo.jobscheduler.Datasource;
import org.opensearch.geospatial.ip2geo.jobscheduler.DatasourceExtension;
import org.opensearch.geospatial.shared.StashedThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -85,7 +84,7 @@ public void createIndexIfNotExists(final StepListener<Void> stepListener) {
}
final CreateIndexRequest createIndexRequest = new CreateIndexRequest(DatasourceExtension.JOB_INDEX_NAME).mapping(getIndexMapping())
.settings(DatasourceExtension.INDEX_SETTING);
StashedThreadContext.run(client, () -> client.admin().indices().create(createIndexRequest, new ActionListener<>() {
client.admin().indices().create(createIndexRequest, new ActionListener<>() {
@Override
public void onResponse(final CreateIndexResponse createIndexResponse) {
stepListener.onResponse(null);
Expand All @@ -100,7 +99,7 @@ public void onFailure(final Exception e) {
}
stepListener.onFailure(e);
}
}));
});
}

private String getIndexMapping() {
Expand All @@ -122,19 +121,17 @@ private String getIndexMapping() {
*/
public IndexResponse updateDatasource(final Datasource datasource) {
datasource.setLastUpdateTime(Instant.now());
return StashedThreadContext.run(client, () -> {
try {
return client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME)
.setId(datasource.getName())
.setOpType(DocWriteRequest.OpType.INDEX)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));
} catch (IOException e) {
throw new RuntimeException(e);
}
});
try {
return client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME)
.setId(datasource.getName())
.setOpType(DocWriteRequest.OpType.INDEX)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));
} catch (IOException e) {
throw new RuntimeException(e);
}
}

/**
Expand All @@ -148,7 +145,7 @@ public void updateDatasource(final List<Datasource> datasources, final ActionLis
datasource.setLastUpdateTime(Instant.now());
return datasource;
}).map(this::toIndexRequest).forEach(indexRequest -> bulkRequest.add(indexRequest));
StashedThreadContext.run(client, () -> client.bulk(bulkRequest, listener));
client.bulk(bulkRequest, listener);
}

private IndexRequest toIndexRequest(Datasource datasource) {
Expand All @@ -173,18 +170,16 @@ private IndexRequest toIndexRequest(Datasource datasource) {
*/
public void putDatasource(final Datasource datasource, final ActionListener listener) {
datasource.setLastUpdateTime(Instant.now());
StashedThreadContext.run(client, () -> {
try {
client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME)
.setId(datasource.getName())
.setOpType(DocWriteRequest.OpType.CREATE)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
.execute(listener);
} catch (IOException e) {
new RuntimeException(e);
}
});
try {
client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME)
.setId(datasource.getName())
.setOpType(DocWriteRequest.OpType.CREATE)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
.execute(listener);
} catch (IOException e) {
new RuntimeException(e);
}
}

/**
Expand Down Expand Up @@ -220,7 +215,7 @@ public Datasource getDatasource(final String name) throws IOException {
GetRequest request = new GetRequest(DatasourceExtension.JOB_INDEX_NAME, name);
GetResponse response;
try {
response = StashedThreadContext.run(client, () -> client.get(request).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT)));
response = client.get(request).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));
if (response.isExists() == false) {
log.error("Datasource[{}] does not exist in an index[{}]", name, DatasourceExtension.JOB_INDEX_NAME);
return null;
Expand All @@ -245,7 +240,7 @@ public Datasource getDatasource(final String name) throws IOException {
*/
public void getDatasource(final String name, final ActionListener<Datasource> actionListener) {
GetRequest request = new GetRequest(DatasourceExtension.JOB_INDEX_NAME, name);
StashedThreadContext.run(client, () -> client.get(request, new ActionListener<>() {
client.get(request, new ActionListener<>() {
@Override
public void onResponse(final GetResponse response) {
if (response.isExists() == false) {
Expand All @@ -269,7 +264,7 @@ public void onResponse(final GetResponse response) {
public void onFailure(final Exception e) {
actionListener.onFailure(e);
}
}));
});
}

/**
Expand All @@ -278,42 +273,33 @@ public void onFailure(final Exception e) {
* @param actionListener the action listener
*/
public void getDatasources(final String[] names, final ActionListener<List<Datasource>> actionListener) {
StashedThreadContext.run(
client,
() -> client.prepareMultiGet()
.add(DatasourceExtension.JOB_INDEX_NAME, names)
.execute(createGetDataSourceQueryActionLister(MultiGetResponse.class, actionListener))
);
client.prepareMultiGet()
.add(DatasourceExtension.JOB_INDEX_NAME, names)
.execute(createGetDataSourceQueryActionLister(MultiGetResponse.class, actionListener));
}

/**
* Get all datasources up to {@code MAX_SIZE} from an index {@code DatasourceExtension.JOB_INDEX_NAME}
* @param actionListener the action listener
*/
public void getAllDatasources(final ActionListener<List<Datasource>> actionListener) {
StashedThreadContext.run(
client,
() -> client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME)
.setQuery(QueryBuilders.matchAllQuery())
.setPreference(Preference.PRIMARY.type())
.setSize(MAX_SIZE)
.execute(createGetDataSourceQueryActionLister(SearchResponse.class, actionListener))
);
client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME)
.setQuery(QueryBuilders.matchAllQuery())
.setPreference(Preference.PRIMARY.type())
.setSize(MAX_SIZE)
.execute(createGetDataSourceQueryActionLister(SearchResponse.class, actionListener));
}

/**
* Get all datasources up to {@code MAX_SIZE} from an index {@code DatasourceExtension.JOB_INDEX_NAME}
*/
public List<Datasource> getAllDatasources() {
SearchResponse response = StashedThreadContext.run(
client,
() -> client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME)
.setQuery(QueryBuilders.matchAllQuery())
.setPreference(Preference.PRIMARY.type())
.setSize(MAX_SIZE)
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT))
);
SearchResponse response = client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME)
.setQuery(QueryBuilders.matchAllQuery())
.setPreference(Preference.PRIMARY.type())
.setSize(MAX_SIZE)
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));

List<BytesReference> bytesReferences = toBytesReferences(response);
return bytesReferences.stream().map(bytesRef -> toDatasource(bytesRef)).collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
import org.opensearch.geospatial.ip2geo.common.Ip2GeoSettings;
import org.opensearch.geospatial.ip2geo.common.URLDenyListChecker;
import org.opensearch.geospatial.shared.Constants;
import org.opensearch.geospatial.shared.StashedThreadContext;
import org.opensearch.index.query.QueryBuilders;

import lombok.NonNull;
Expand Down Expand Up @@ -117,24 +116,19 @@ public void createIndexIfNotExists(final String indexName) {
}
final CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(INDEX_SETTING_TO_CREATE)
.mapping(getIndexMapping());
StashedThreadContext.run(
client,
() -> client.admin().indices().create(createIndexRequest).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT))
);
client.admin().indices().create(createIndexRequest).actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));
}

private void freezeIndex(final String indexName) {
TimeValue timeout = clusterSettings.get(Ip2GeoSettings.TIMEOUT);
StashedThreadContext.run(client, () -> {
client.admin().indices().prepareForceMerge(indexName).setMaxNumSegments(1).execute().actionGet(timeout);
client.admin().indices().prepareRefresh(indexName).execute().actionGet(timeout);
client.admin()
.indices()
.prepareUpdateSettings(indexName)
.setSettings(INDEX_SETTING_TO_FREEZE)
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));
});
client.admin().indices().prepareForceMerge(indexName).setMaxNumSegments(1).execute().actionGet(timeout);
client.admin().indices().prepareRefresh(indexName).execute().actionGet(timeout);
client.admin()
.indices()
.prepareUpdateSettings(indexName)
.setSettings(INDEX_SETTING_TO_FREEZE)
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));
}

/**
Expand Down Expand Up @@ -249,15 +243,12 @@ public XContentBuilder createDocument(final String[] fields, final String[] valu
* @return geoIP data
*/
public Map<String, Object> getGeoIpData(final String indexName, final String ip) {
SearchResponse response = StashedThreadContext.run(
client,
() -> client.prepareSearch(indexName)
.setSize(1)
.setQuery(QueryBuilders.termQuery(IP_RANGE_FIELD_NAME, ip))
.setPreference(Preference.LOCAL.type())
.setRequestCache(true)
.get(clusterSettings.get(Ip2GeoSettings.TIMEOUT))
);
SearchResponse response = client.prepareSearch(indexName)
.setSize(1)
.setQuery(QueryBuilders.termQuery(IP_RANGE_FIELD_NAME, ip))
.setPreference(Preference.LOCAL.type())
.setRequestCache(true)
.get(clusterSettings.get(Ip2GeoSettings.TIMEOUT));

if (response.getHits().getHits().length == 0) {
return Collections.emptyMap();
Expand Down Expand Up @@ -297,7 +288,7 @@ public void putGeoIpData(
indexRequest.id(record.get(0));
bulkRequest.add(indexRequest);
if (iterator.hasNext() == false || bulkRequest.requests().size() == batchSize) {
BulkResponse response = StashedThreadContext.run(client, () -> client.bulk(bulkRequest).actionGet(timeout));
BulkResponse response = client.bulk(bulkRequest).actionGet(timeout);
if (response.hasFailures()) {
throw new OpenSearchException(
"error occurred while ingesting GeoIP data in {} with an error {}",
Expand Down Expand Up @@ -334,15 +325,12 @@ public void deleteIp2GeoDataIndex(final List<String> indices) {
);
}

AcknowledgedResponse response = StashedThreadContext.run(
client,
() -> client.admin()
.indices()
.prepareDelete(indices.toArray(new String[0]))
.setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT))
);
AcknowledgedResponse response = client.admin()
.indices()
.prepareDelete(indices.toArray(new String[0]))
.setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
.execute()
.actionGet(clusterSettings.get(Ip2GeoSettings.TIMEOUT));

if (response.isAcknowledged() == false) {
throw new OpenSearchException("failed to delete data[{}] in datasource", String.join(",", indices));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,15 @@ public String getType() {
public static final class Factory implements Processor.Factory {
private static final ParameterValidator VALIDATOR = new ParameterValidator();
private final IngestService ingestService;
private final DatasourceDao datasourceDao;
private final GeoIpDataDao geoIpDataDao;
private final Ip2GeoCachedDao ip2GeoCachedDao;
private DatasourceDao datasourceDao;
private GeoIpDataDao geoIpDataDao;
private Ip2GeoCachedDao ip2GeoCachedDao;

public Factory(
final IngestService ingestService,
final DatasourceDao datasourceDao,
final GeoIpDataDao geoIpDataDao,
final Ip2GeoCachedDao ip2GeoCachedDao
) {
public Factory(final IngestService ingestService) {
this.ingestService = ingestService;
}

public void initialize(final DatasourceDao datasourceDao, final GeoIpDataDao geoIpDataDao, final Ip2GeoCachedDao ip2GeoCachedDao) {
this.datasourceDao = datasourceDao;
this.geoIpDataDao = geoIpDataDao;
this.ip2GeoCachedDao = ip2GeoCachedDao;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,20 @@
import org.opensearch.geospatial.rest.action.upload.geojson.RestUploadGeoJSONAction;
import org.opensearch.geospatial.search.aggregations.bucket.geogrid.GeoHexGrid;
import org.opensearch.geospatial.search.aggregations.bucket.geogrid.GeoHexGridAggregationBuilder;
import org.opensearch.geospatial.shared.RunAsSubjectClient;
import org.opensearch.geospatial.stats.upload.RestUploadStatsAction;
import org.opensearch.geospatial.stats.upload.UploadStats;
import org.opensearch.geospatial.stats.upload.UploadStatsAction;
import org.opensearch.geospatial.stats.upload.UploadStatsTransportAction;
import org.opensearch.identity.PluginSubject;
import org.opensearch.index.IndexModule;
import org.opensearch.index.mapper.Mapper;
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.ingest.Processor;
import org.opensearch.jobscheduler.spi.utils.LockService;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ClusterPlugin;
import org.opensearch.plugins.IdentityAwarePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.MapperPlugin;
import org.opensearch.plugins.Plugin;
Expand Down Expand Up @@ -109,30 +112,33 @@ public class GeospatialPlugin extends Plugin
MapperPlugin,
SearchPlugin,
SystemIndexPlugin,
ClusterPlugin {
ClusterPlugin,
IdentityAwarePlugin {
private Ip2GeoCachedDao ip2GeoCachedDao;
private DatasourceDao datasourceDao;
private GeoIpDataDao geoIpDataDao;
private Ip2GeoProcessor.Factory ip2geoProcessor;
private URLDenyListChecker urlDenyListChecker;
private ClusterService clusterService;
private Ip2GeoLockService ip2GeoLockService;
private Ip2GeoExecutor ip2GeoExecutor;
private DatasourceUpdateService datasourceUpdateService;
private RunAsSubjectClient pluginClient;

@Override
public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings settings) {
return List.of(new SystemIndexDescriptor(IP2GEO_DATA_INDEX_NAME_PREFIX, "System index used for Ip2Geo data"));
return List.of(
new SystemIndexDescriptor(IP2GEO_DATA_INDEX_NAME_PREFIX + "*", "System index pattern used for Ip2Geo data"),
new SystemIndexDescriptor(DatasourceExtension.JOB_INDEX_NAME, "System index used for Ip2Geo job")
);
}

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
this.urlDenyListChecker = new URLDenyListChecker(parameters.ingestService.getClusterService().getClusterSettings());
this.datasourceDao = new DatasourceDao(parameters.client, parameters.ingestService.getClusterService());
this.geoIpDataDao = new GeoIpDataDao(parameters.ingestService.getClusterService(), parameters.client, urlDenyListChecker);
this.ip2GeoCachedDao = new Ip2GeoCachedDao(parameters.ingestService.getClusterService(), datasourceDao, geoIpDataDao);
this.ip2geoProcessor = new Ip2GeoProcessor.Factory(parameters.ingestService);
return MapBuilder.<String, Processor.Factory>newMapBuilder()
.put(FeatureProcessor.TYPE, new FeatureProcessor.Factory())
.put(Ip2GeoProcessor.TYPE, new Ip2GeoProcessor.Factory(parameters.ingestService, datasourceDao, geoIpDataDao, ip2GeoCachedDao))
.put(Ip2GeoProcessor.TYPE, ip2geoProcessor)
.immutableMap();
}

Expand Down Expand Up @@ -179,6 +185,14 @@ public Collection<Object> createComponents(
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
this.clusterService = clusterService;
this.pluginClient = new RunAsSubjectClient(client);
this.urlDenyListChecker = new URLDenyListChecker(clusterService.getClusterSettings());
this.datasourceDao = new DatasourceDao(pluginClient, clusterService);
this.geoIpDataDao = new GeoIpDataDao(clusterService, pluginClient, urlDenyListChecker);
this.ip2GeoCachedDao = new Ip2GeoCachedDao(clusterService, datasourceDao, geoIpDataDao);
if (this.ip2geoProcessor != null) {
this.ip2geoProcessor.initialize(datasourceDao, geoIpDataDao, ip2GeoCachedDao);
}
this.datasourceUpdateService = new DatasourceUpdateService(clusterService, datasourceDao, geoIpDataDao, urlDenyListChecker);
this.ip2GeoExecutor = new Ip2GeoExecutor(threadPool);
this.ip2GeoLockService = new Ip2GeoLockService(clusterService);
Expand Down Expand Up @@ -285,6 +299,13 @@ public void onNodeStarted(DiscoveryNode localNode) {
.initialize(this.clusterService, this.datasourceUpdateService, this.ip2GeoExecutor, this.datasourceDao, this.ip2GeoLockService);
}

@Override
public void assignSubject(PluginSubject pluginSubject) {
if (this.pluginClient != null) {
this.pluginClient.setSubject(pluginSubject);
}
}

public static class GuiceHolder implements LifecycleComponent {

private static LockService lockService;
Expand Down
Loading
Loading