diff --git a/samples/Samples/NeuralSearch/NeuralSearchSample.cs b/samples/Samples/NeuralSearch/NeuralSearchSample.cs
new file mode 100644
index 0000000000..dc459d5418
--- /dev/null
+++ b/samples/Samples/NeuralSearch/NeuralSearchSample.cs
@@ -0,0 +1,266 @@
+/* SPDX-License-Identifier: Apache-2.0
+*
+* The OpenSearch Contributors require contributions made to
+* this file be licensed under the Apache-2.0 license or a
+* compatible open source license.
+*/
+
+using System.Diagnostics;
+using OpenSearch.Client;
+using OpenSearch.Net;
+
+namespace Samples.NeuralSearch;
+
+///
+/// Sample based off of the Neural Search Tutorial
+///
+public class NeuralSearchSample : Sample
+{
+ private const string SampleName = "neural-search";
+ private const string ResourceNamePrefix = "csharp-" + SampleName;
+ private const string MlModelGroupName = ResourceNamePrefix + "-model-group";
+ private const string IngestPipelineName = ResourceNamePrefix + "-ingest-pipeline";
+ private const string IndexName = ResourceNamePrefix + "-index";
+
+ private string? _modelGroupId;
+ private string? _modelRegistrationTaskId;
+ private string? _modelId;
+ private string? _modelDeployTaskId;
+ private bool _putIngestPipeline;
+ private bool _createdIndex;
+
+ public NeuralSearchSample() : base(SampleName, "A sample demonstrating how to perform a neural search query") { }
+
+ public class NeuralSearchDoc
+ {
+ [PropertyName("id")] public string? Id { get; set; }
+ [PropertyName("text")] public string? Text { get; set; }
+ [PropertyName("passage_embedding")] public float[]? PassageEmbedding { get; set; }
+ }
+
+ protected override async Task Run(IOpenSearchClient client)
+ {
+ // Temporarily configure the cluster to allow local running of the ML model
+ var putSettingsResp = await client.Cluster.PutSettingsAsync(s => s
+ .Transient(p => p
+ .Add("plugins.ml_commons.only_run_on_ml_node", false)
+ .Add("plugins.ml_commons.model_access_control_enabled", true)
+ .Add("plugins.ml_commons.native_memory_threshold", 99)));
+ Debug.Assert(putSettingsResp.IsValid, putSettingsResp.DebugInformation);
+ Console.WriteLine("Configured cluster to allow local execution of the ML model");
+
+ // Register an ML model group
+ var registerModelGroupResp = await client.Http.PostAsync(
+ "/_plugins/_ml/model_groups/_register",
+ r => r.SerializableBody(new
+ {
+ name = MlModelGroupName,
+ description = $"A model group for the opensearch-net {SampleName} sample",
+ access_mode = "public"
+ }));
+ Debug.Assert(registerModelGroupResp.Success && (string) registerModelGroupResp.Body.status == "CREATED", registerModelGroupResp.DebugInformation);
+ Console.WriteLine($"Model group named {MlModelGroupName} {registerModelGroupResp.Body.status}: {registerModelGroupResp.Body.model_group_id}");
+ _modelGroupId = (string) registerModelGroupResp.Body.model_group_id;
+
+ // Register the ML model
+ var registerModelResp = await client.Http.PostAsync(
+ "/_plugins/_ml/models/_register",
+ r => r.SerializableBody(new
+ {
+ name = "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b",
+ version = "1.0.1",
+ model_group_id = _modelGroupId,
+ model_format = "TORCH_SCRIPT"
+ }));
+ Debug.Assert(registerModelResp.Success && (string) registerModelResp.Body.status == "CREATED", registerModelResp.DebugInformation);
+ Console.WriteLine($"Model registration task {registerModelResp.Body.status}: {registerModelResp.Body.task_id}");
+ _modelRegistrationTaskId = (string) registerModelResp.Body.task_id;
+
+ // Wait for ML model registration to complete
+ while (true)
+ {
+ var getTaskResp = await client.Http.GetAsync($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}");
+ Console.WriteLine($"Model registration: {getTaskResp.Body.state}");
+ Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation);
+ if (((string)getTaskResp.Body.state).StartsWith("COMPLETED"))
+ {
+ _modelId = getTaskResp.Body.model_id;
+ break;
+ }
+ await Task.Delay(10000);
+ }
+ Console.WriteLine($"Model registered: {_modelId}");
+
+ // Deploy the ML model
+ var deployModelResp = await client.Http.PostAsync($"/_plugins/_ml/models/{_modelId}/_deploy");
+ Debug.Assert(deployModelResp.Success && (string) deployModelResp.Body.status == "CREATED", deployModelResp.DebugInformation);
+ Console.WriteLine($"Model deployment task {deployModelResp.Body.status}: {deployModelResp.Body.task_id}");
+ _modelDeployTaskId = (string) deployModelResp.Body.task_id;
+
+ // Wait for ML model deployment to complete
+ while (true)
+ {
+ var getTaskResp = await client.Http.GetAsync($"/_plugins/_ml/tasks/{_modelDeployTaskId}");
+ Console.WriteLine($"Model deployment: {getTaskResp.Body.state}");
+ Debug.Assert(getTaskResp.Success && (string) getTaskResp.Body.state != "FAILED", getTaskResp.DebugInformation);
+ if (((string)getTaskResp.Body.state).StartsWith("COMPLETED")) break;
+ await Task.Delay(10000);
+ }
+ Console.WriteLine($"Model deployed: {_modelId}");
+
+ // Create the text_embedding ingest pipeline
+ // TODO: Client does not yet contain typings for the text_embedding processor
+ var putIngestPipelineResp = await client.Http.PutAsync(
+ $"/_ingest/pipeline/{IngestPipelineName}",
+ r => r.SerializableBody(new
+ {
+ description = $"A text_embedding ingest pipeline for the opensearch-net {SampleName} sample",
+ processors = new[]
+ {
+ new
+ {
+ text_embedding = new
+ {
+ model_id = _modelId,
+ field_map = new
+ {
+ text = "passage_embedding"
+ }
+ }
+ }
+ }
+ }));
+ Debug.Assert(putIngestPipelineResp.IsValid, putIngestPipelineResp.DebugInformation);
+ Console.WriteLine($"Put ingest pipeline {IngestPipelineName}: {putIngestPipelineResp.Acknowledged}");
+ _putIngestPipeline = true;
+
+ // Create the index
+ var createIndexResp = await client.Indices.CreateAsync(
+ IndexName,
+ i => i
+ .Settings(s => s
+ .Setting("index.knn", true)
+ .DefaultPipeline(IngestPipelineName))
+ .Map(m => m
+ .Properties(p => p
+ .Text(t => t.Name(d => d.Id))
+ .Text(t => t.Name(d => d.Text))
+ .KnnVector(k => k
+ .Name(d => d.PassageEmbedding)
+ .Dimension(768)
+ .Method(km => km
+ .Engine("lucene")
+ .SpaceType("l2")
+ .Name("hnsw"))))));
+ Debug.Assert(createIndexResp.IsValid, createIndexResp.DebugInformation);
+ Console.WriteLine($"Created index {IndexName}: {createIndexResp.Acknowledged}");
+ _createdIndex = true;
+
+ // Index some documents
+ var documents = new NeuralSearchDoc[]
+ {
+ new() { Id = "4319130149.jpg", Text = "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena ." },
+ new() { Id = "1775029934.jpg", Text = "A wild animal races across an uncut field with a minimal amount of trees ." },
+ new() { Id = "2664027527.jpg", Text = "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco ." },
+ new() { Id = "4427058951.jpg", Text = "A man who is riding a wild horse in the rodeo is very near to falling off ." },
+ new() { Id = "2691147709.jpg", Text = "A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse ." }
+ };
+ var bulkResp = await client.BulkAsync(b => b
+ .Index(IndexName)
+ .IndexMany(documents)
+ .Refresh(Refresh.WaitFor));
+ Debug.Assert(bulkResp.IsValid, bulkResp.DebugInformation);
+ Console.WriteLine($"Indexed {documents.Length} documents");
+
+ // Perform the neural search
+ // TODO: Client does not yet contain typings for neural query type
+ Console.WriteLine("Performing neural search for text 'wild west'");
+ var searchResp = await client.Http.PostAsync>(
+ $"/{IndexName}/_search",
+ r => r.SerializableBody(new
+ {
+ _source = new { excludes = new[] { "passage_embedding" } },
+ query = new
+ {
+ neural = new
+ {
+ passage_embedding = new
+ {
+ query_text = "wild west",
+ model_id = _modelId,
+ k = 5
+ }
+ }
+ }
+ }));
+ Debug.Assert(searchResp.IsValid, searchResp.DebugInformation);
+ Console.WriteLine($"Found {searchResp.Hits.Count} documents");
+ foreach (var hit in searchResp.Hits) Console.WriteLine($"- Document id: {hit.Source.Id}, score: {hit.Score}, text: {hit.Source.Text}");
+ }
+
+ protected override async Task Cleanup(IOpenSearchClient client)
+ {
+ Console.WriteLine("\n\n-- CLEANING UP --");
+ if (_createdIndex)
+ {
+ // Cleanup the index
+ var deleteIndexResp = await client.Indices.DeleteAsync(IndexName);
+ Debug.Assert(deleteIndexResp.IsValid, deleteIndexResp.DebugInformation);
+ Console.WriteLine($"Deleted index: {deleteIndexResp.Acknowledged}");
+ }
+
+ if (_putIngestPipeline)
+ {
+ // Cleanup the ingest pipeline
+ var deleteIngestPipelineResp = await client.Ingest.DeletePipelineAsync(IngestPipelineName);
+ Debug.Assert(deleteIngestPipelineResp.IsValid, deleteIngestPipelineResp.DebugInformation);
+ Console.WriteLine($"Deleted ingest pipeline: {deleteIngestPipelineResp.Acknowledged}");
+ }
+
+ if (_modelDeployTaskId != null)
+ {
+ // Cleanup the model deployment task
+ var deleteModelDeployTaskResp = await client.Http.DeleteAsync($"/_plugins/_ml/tasks/{_modelDeployTaskId}");
+ Debug.Assert(deleteModelDeployTaskResp.Success && (string) deleteModelDeployTaskResp.Body.result == "deleted", deleteModelDeployTaskResp.DebugInformation);
+ Console.WriteLine($"Deleted model deployment task: {deleteModelDeployTaskResp.Body.result}");
+ }
+
+ if (_modelId != null)
+ {
+ while (true)
+ {
+ // Try cleanup the ML model
+ var deleteModelResp = await client.Http.DeleteAsync($"/_plugins/_ml/models/{_modelId}");
+ if (deleteModelResp.Success)
+ {
+ Console.WriteLine($"Deleted model: {deleteModelResp.Body.result}");
+ break;
+ }
+
+ Debug.Assert(((string?)deleteModelResp.Body.error?.reason)?.Contains("Try undeploy") ?? false, deleteModelResp.DebugInformation);
+
+ // Undeploy the ML model
+ var undeployModelResp = await client.Http.PostAsync($"/_plugins/_ml/models/{_modelId}/_undeploy");
+ Debug.Assert(undeployModelResp.Success, undeployModelResp.DebugInformation);
+ Console.WriteLine("Undeployed model");
+ await Task.Delay(10000);
+ }
+ }
+
+ if (_modelRegistrationTaskId != null)
+ {
+ // Cleanup the model registration task
+ var deleteModelRegistrationTaskResp = await client.Http.DeleteAsync($"/_plugins/_ml/tasks/{_modelRegistrationTaskId}");
+ Debug.Assert(deleteModelRegistrationTaskResp.Success && (string) deleteModelRegistrationTaskResp.Body.result == "deleted", deleteModelRegistrationTaskResp.DebugInformation);
+ Console.WriteLine($"Deleted model registration task: {deleteModelRegistrationTaskResp.Body.result}");
+ }
+
+ if (_modelGroupId != null)
+ {
+ // Cleanup the model group
+ var deleteModelGroupResp = await client.Http.DeleteAsync($"/_plugins/_ml/model_groups/{_modelGroupId}");
+ Debug.Assert(deleteModelGroupResp.Success && (string) deleteModelGroupResp.Body.result == "deleted", deleteModelGroupResp.DebugInformation);
+ Console.WriteLine($"Deleted model group: {deleteModelGroupResp.Body.result}");
+ }
+ }
+}
diff --git a/samples/Samples/Sample.cs b/samples/Samples/Sample.cs
index 6126255c99..e683b98935 100644
--- a/samples/Samples/Sample.cs
+++ b/samples/Samples/Sample.cs
@@ -33,10 +33,29 @@ public Command AsCommand(IValueDescriptor clientDescriptor)
{
var command = new Command(_name, _description);
- command.SetHandler(Run, clientDescriptor);
+ command.SetHandler(async client =>
+ {
+ try
+ {
+ await Run(client);
+ }
+ finally
+ {
+ try
+ {
+ await Cleanup(client);
+ }
+ catch (Exception e)
+ {
+ await Console.Error.WriteLineAsync($"Cleanup Failed: {e}");
+ }
+ }
+ }, clientDescriptor);
return command;
}
protected abstract Task Run(IOpenSearchClient client);
+
+ protected virtual Task Cleanup(IOpenSearchClient client) => Task.CompletedTask;
}