Skip to content

Commit

Permalink
enable embedding search (#164)
Browse files Browse the repository at this point in the history
## Purpose
<!-- Describe the intention of the changes being proposed. What problem
does it solve or functionality does it add? -->
* ...

## Does this introduce a breaking change?
<!-- Mark one with an "x". -->
```
[ ] Yes
[ ] No
```
fix #120

This PR enables option for three search mode
- Text
- Vector
- Hybrid

When `Vector` or `Hybrid` mode is enabled, vector search will be enabled
when searching document from index.

<img width="443" alt="image"
src="https://github.com/Azure-Samples/azure-search-openai-demo-csharp/assets/16876986/6326ce20-bed7-49f5-aae0-e62f2505dcdd">


## Pull Request Type
What kind of change does this Pull Request introduce?

<!-- Please check the one that applies to this PR using "x". -->
```
[ ] Bugfix
[ ] Feature
[ ] Code style update (formatting, local variables)
[ ] Refactoring (no functional changes, no api changes)
[ ] Documentation content changes
[ ] Other... Please describe:
```

## How to Test
*  Get the code

```
git clone [repo-address]
cd [repo-name]
git checkout [branch-name]
npm install
```

* Test the code
<!-- Add steps to run the tests suite and/or manually test -->
```
```

## What to Check
Verify that the following are valid
* ...

## Other Information
<!-- Add any other helpful information that may be needed here. -->
  • Loading branch information
LittleLittleCloud authored Sep 28, 2023
1 parent ace365f commit d74cef4
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 98 deletions.
2 changes: 1 addition & 1 deletion app/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<PropertyGroup>
<AzureFormRecognizerVersion>4.1.0</AzureFormRecognizerVersion>
<AzureIdentityVersion>1.10.0</AzureIdentityVersion>
<AzureSearchDocumentsVersion>11.5.0-beta.2</AzureSearchDocumentsVersion>
<AzureSearchDocumentsVersion>11.5.0-beta.4</AzureSearchDocumentsVersion>
<AzureStorageBlobsVersion>12.17.0</AzureStorageBlobsVersion>
<SemanticKernelVersion>0.24.230918.1-preview</SemanticKernelVersion>
<AzureOpenAIVersion>1.0.0-beta.7</AzureOpenAIVersion>
Expand Down
64 changes: 16 additions & 48 deletions app/backend/Extensions/SearchClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ internal static class SearchClientExtensions
{
internal static async Task<string> QueryDocumentsAsync(
this SearchClient searchClient,
string query,
string? query = null,
float[]? embedding = null,
RequestOverrides? overrides = null,
CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -34,6 +35,20 @@ internal static async Task<string> QueryDocumentsAsync(
Size = top,
};

if (embedding != null && overrides?.RetrievalMode != "Text")
{
var k = useSemanticRanker ? 50 : top;
var vectorQuery = new SearchQueryVector
{
// if semantic ranker is enabled, we need to set the rank to a large number to get more
// candidates for semantic reranking
KNearestNeighborsCount = useSemanticRanker ? 50 : top,
Value = embedding,
};
vectorQuery.Fields.Add("embedding");
searchOption.Vectors.Add(vectorQuery);
}

var searchResultResponse = await searchClient.SearchAsync<SearchDocument>(query, searchOption, cancellationToken);
if (searchResultResponse.Value is null)
{
Expand Down Expand Up @@ -85,51 +100,4 @@ internal static async Task<string> QueryDocumentsAsync(

return documentContents;
}

internal static async Task<string> LookupAsync(
this SearchClient searchClient,
string query,
RequestOverrides? overrides = null)
{
var option = new SearchOptions
{
Size = 1,
IncludeTotalCount = true,
QueryType = SearchQueryType.Semantic,
QueryLanguage = "en-us",
QuerySpeller = "lexicon",
SemanticConfigurationName = "default",
QueryAnswer = "extractive",
QueryCaption = "extractive",
};

var searchResultResponse = await searchClient.SearchAsync<SearchDocument>(query, option);
if (searchResultResponse.Value is null)
{
throw new InvalidOperationException("fail to get search result");
}

var searchResult = searchResultResponse.Value;
if (searchResult is { Answers.Count: > 0 })
{
return searchResult.Answers[0].Text;
}

if (searchResult.TotalCount > 0)
{
var contents = new List<string>();
await foreach (var doc in searchResult.GetResultsAsync())
{
doc.Document.TryGetValue("content", out var contentValue);
if (contentValue is string content)
{
contents.Add(content);
}
}

return string.Join("\n", contents);
}

return string.Empty;
}
}
3 changes: 1 addition & 2 deletions app/backend/Extensions/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv
services.AddSingleton<DocumentAnalysisClient>(sp =>
{
var config = sp.GetRequiredService<IConfiguration>();
var azureOpenAiServiceEndpoint = config["AzureOpenAiServiceEndpoint"];
ArgumentNullException.ThrowIfNullOrEmpty(azureOpenAiServiceEndpoint);
var azureOpenAiServiceEndpoint = config["AzureOpenAiServiceEndpoint"] ?? throw new ArgumentNullException();

var documentAnalysisClient = new DocumentAnalysisClient(
new Uri(azureOpenAiServiceEndpoint), s_azureCredential);
Expand Down
59 changes: 39 additions & 20 deletions app/backend/Services/ReadRetrieveReadChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ public ReadRetrieveReadChatService(
{
_searchClient = searchClient;
var deployedModelName = configuration["AzureOpenAiChatGptDeployment"] ?? throw new ArgumentNullException();
_kernel = Kernel.Builder.WithAzureChatCompletionService(deployedModelName, client).Build();
var kernelBuilder = Kernel.Builder.WithAzureChatCompletionService(deployedModelName, client);
var embeddingModelName = configuration["AzureOpenAiEmbeddingDeployment"];
if (!string.IsNullOrEmpty(embeddingModelName))
{
var endpoint = configuration["AzureOpenAiServiceEndpoint"] ?? throw new ArgumentNullException();
kernelBuilder = kernelBuilder.WithAzureTextEmbeddingGenerationService(embeddingModelName, endpoint, new DefaultAzureCredential());
}
_kernel = kernelBuilder.Build();
_configuration = configuration;
}

Expand All @@ -34,37 +41,49 @@ public async Task<ApproachResponse> ReplyAsync(
var excludeCategory = overrides?.ExcludeCategory ?? null;
var filter = excludeCategory is null ? null : $"category ne '{excludeCategory}'";
IChatCompletion chat = _kernel.GetService<IChatCompletion>();
ITextEmbeddingGeneration? embedding = _kernel.GetService<ITextEmbeddingGeneration>();
float[]? embeddings = null;
var question = history.LastOrDefault()?.User is { } userQuestion
? userQuestion
: throw new InvalidOperationException("Use question is null");
if (overrides?.RetrievalMode != "Text" && embedding is not null)
{
embeddings = (await embedding.GenerateEmbeddingAsync(question)).ToArray();
}

// step 1
// use llm to get query
var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question.
// use llm to get query if retrieval mode is not vector
string? query = null;
if (overrides?.RetrievalMode != "Vector")
{
var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question.
Make your respond simple and precise. Return the query only, do not return any other text.
e.g.
Northwind Health Plus AND standard plan.
standard plan AND dental AND employee benefit.
");
var question = history.LastOrDefault()?.User is { } userQuestion
? userQuestion
: throw new InvalidOperationException("Use question is null");
getQueryChat.AddUserMessage(question);
var result = await chat.GetChatCompletionsAsync(
getQueryChat,
new ChatRequestSettings

getQueryChat.AddUserMessage(question);
var result = await chat.GetChatCompletionsAsync(
getQueryChat,
new ChatRequestSettings
{
Temperature = 0,
MaxTokens = 128,
},
cancellationToken);

if (result.Count != 1)
{
Temperature = 0,
MaxTokens = 128,
},
cancellationToken);
throw new InvalidOperationException("Failed to get search query");
}

if (result.Count != 1)
{
throw new InvalidOperationException("Failed to get search query");
query = result[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content;
}

var query = result[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content;

// step 2
// use query to search related docs
var documentContents = await _searchClient.QueryDocumentsAsync(query, overrides, cancellationToken);
var documentContents = await _searchClient.QueryDocumentsAsync(query, embeddings, overrides, cancellationToken);

if (string.IsNullOrEmpty(documentContents))
{
Expand Down
28 changes: 12 additions & 16 deletions app/frontend/Components/SettingsPanel.razor
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,6 @@
</MudText>
</MudDrawerHeader>
<div class="pa-6">
@if (_supportedSettings is not SupportedSettings.Chat)
{
<MudText Typo="Typo.subtitle1">Approach</MudText>
<MudRadioGroup @bind-SelectedOption="@Settings.Approach" Class="pa-2">
<MudRadio Option="@Approach.RetrieveThenRead" Color="Color.Primary">
Retrieve-Then-Read
</MudRadio>
<MudRadio Option="@Approach.ReadRetrieveRead" Color="Color.Primary">
Read-Retrieve-Read
</MudRadio>
<MudRadio Option="@Approach.ReadDecomposeAsk" Color="Color.Primary">
Read-Decompose-Ask
</MudRadio>
</MudRadioGroup>
}

<MudTextField T="string" Lines="5" Variant="Variant.Outlined"
Label="Override prompt template" Placeholder="Override prompt template" Class="pa-2"
@bind-Value="Settings.Overrides.PromptTemplate" Clearable="true" />
Expand All @@ -41,6 +25,18 @@
<MudCheckBox @bind-Checked="@Settings.Overrides.SemanticRanker" Size="Size.Large"
Color="Color.Primary" Label="Use semantic ranker for retrieval" />

<MudText Typo="Typo.subtitle1">Retrieval Mode</MudText>
<MudRadioGroup Required="true" @bind-SelectedOption="@Settings.Overrides.RetrievalMode" Class="pa-2">
<MudRadio T="string" Option="@("Text")" Color="Color.Primary">
Text
</MudRadio>
<MudRadio T="string" Option="@("Hybrid")" Color="Color.Primary">
Hybrid
</MudRadio>
<MudRadio T="string" Option="@("Vector")" Color="Color.Primary">
Vector
</MudRadio>
</MudRadioGroup>
<MudCheckBox @bind-Checked="@Settings.Overrides.SemanticCaptions" Size="Size.Large"
Color="Color.Primary"
Label="Use query-contrextual summaries instead of whole documents" />
Expand Down
2 changes: 2 additions & 0 deletions app/prepdocs/PrepareDocs/AppOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ internal record class AppOptions(
string? Container,
string? TenantId,
string? SearchServiceEndpoint,
string? AzureOpenAIServiceEndpoint,
string? SearchIndexName,
string? EmbeddingModelName,
bool Remove,
bool RemoveAll,
string? FormRecognizerServiceEndpoint,
Expand Down
1 change: 1 addition & 0 deletions app/prepdocs/PrepareDocs/GlobalUsings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
global using PdfSharpCore.Pdf;
global using PdfSharpCore.Pdf.IO;
global using PrepareDocs;
global using Azure.AI.OpenAI;
1 change: 1 addition & 0 deletions app/prepdocs/PrepareDocs/PrepareDocs.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<PackageReference Include="Azure.Storage.Blobs" Version="$(AzureStorageBlobsVersion)" />
<PackageReference Include="Microsoft.Extensions.FileSystemGlobbing" Version="7.0.0" />
<PackageReference Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
<PackageReference Include="Azure.AI.OpenAI" Version="$(AzureOpenAIVersion)" />
<PackageReference Include="PdfSharpCore" Version="1.3.60" />
</ItemGroup>

Expand Down
18 changes: 18 additions & 0 deletions app/prepdocs/PrepareDocs/Program.Clients.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
// Copyright (c) Microsoft. All rights reserved.


internal static partial class Program
{
private static BlobContainerClient? s_corpusContainerClient;
private static BlobContainerClient? s_containerClient;
private static DocumentAnalysisClient? s_documentClient;
private static SearchIndexClient? s_searchIndexClient;
private static SearchClient? s_searchClient;
private static OpenAIClient? s_openAIClient;

private static readonly SemaphoreSlim s_corpusContainerLock = new(1);
private static readonly SemaphoreSlim s_containerLock = new(1);
private static readonly SemaphoreSlim s_documentLock = new(1);
private static readonly SemaphoreSlim s_searchIndexLock = new(1);
private static readonly SemaphoreSlim s_searchLock = new(1);
private static readonly SemaphoreSlim s_openAILock = new(1);

private static Task<BlobContainerClient> GetCorpusBlobContainerClientAsync(AppOptions options) =>
GetLazyClientAsync<BlobContainerClient>(options, s_corpusContainerLock, static async o =>
Expand Down Expand Up @@ -119,6 +122,21 @@ private static Task<SearchClient> GetSearchClientAsync(AppOptions options) =>
return s_searchClient;
});

private static Task<OpenAIClient> GetAzureOpenAIClientAsync(AppOptions options) =>
GetLazyClientAsync<OpenAIClient>(options, s_openAILock, async o =>
{
if (s_openAIClient is null)
{
var endpoint = o.AzureOpenAIServiceEndpoint;
ArgumentNullException.ThrowIfNullOrEmpty(endpoint);
s_openAIClient = new OpenAIClient(
new Uri(endpoint),
DefaultCredential);
}
await Task.CompletedTask;
return s_openAIClient;
});

private static async Task<TClient> GetLazyClientAsync<TClient>(
AppOptions options,
SemaphoreSlim locker,
Expand Down
10 changes: 9 additions & 1 deletion app/prepdocs/PrepareDocs/Program.Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ internal static partial class Program
private static readonly Option<string> s_searchIndexName =
new(name: "--searchindex", description: "Name of the Azure Cognitive Search index where content should be indexed (will be created if it doesn't exist)");

private static readonly Option<string> s_azureOpenAIService =
new(name: "--openaiendpoint", description: "Optional. The Azure OpenAI service endpoint which will be used to extract text, tables and layout from the documents (must exist already)");

private static readonly Option<string> s_embeddingModelName =
new(name: "--embeddingmodel", description: "Optional. Name of the Azure Cognitive Search embedding model to use for embedding content in the search index (will be created if it doesn't exist)");

private static readonly Option<bool> s_remove =
new(name: "--remove", description: "Remove references to this document from blob storage and the search index");

Expand All @@ -45,7 +51,7 @@ internal static partial class Program
""")
{
s_files, s_category, s_skipBlobs, s_storageEndpoint,
s_container, s_tenantId, s_searchService, s_searchIndexName,
s_container, s_tenantId, s_searchService, s_searchIndexName, s_azureOpenAIService, s_embeddingModelName,
s_remove, s_removeAll, s_formRecognizerServiceEndpoint, s_verbose
};

Expand All @@ -58,6 +64,8 @@ internal static partial class Program
TenantId: context.ParseResult.GetValueForOption(s_tenantId),
SearchServiceEndpoint: context.ParseResult.GetValueForOption(s_searchService),
SearchIndexName: context.ParseResult.GetValueForOption(s_searchIndexName),
AzureOpenAIServiceEndpoint: context.ParseResult.GetValueForOption(s_azureOpenAIService),
EmbeddingModelName: context.ParseResult.GetValueForOption(s_embeddingModelName),
Remove: context.ParseResult.GetValueForOption(s_remove),
RemoveAll: context.ParseResult.GetValueForOption(s_removeAll),
FormRecognizerServiceEndpoint: context.ParseResult.GetValueForOption(s_formRecognizerServiceEndpoint),
Expand Down
27 changes: 25 additions & 2 deletions app/prepdocs/PrepareDocs/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ Removing sections from '{fileName ?? "all"}' from search index '{options.SearchI
});
}

if (documentsToDelete.Count == 0)
{
break;
}
Response<IndexDocumentsResult> deleteResponse =
await searchClient.DeleteDocumentsAsync(documentsToDelete);

Expand Down Expand Up @@ -176,15 +180,30 @@ static async ValueTask CreateSearchIndexAsync(AppOptions options)
}
}

string vectorSearchConfigName = "my-vector-config";

var index = new SearchIndex(options.SearchIndexName)
{
VectorSearch = new()
{
AlgorithmConfigurations =
{
new HnswVectorSearchAlgorithmConfiguration(vectorSearchConfigName)
}
},
Fields =
{
new SimpleField("id", SearchFieldDataType.String) { IsKey = true },
new SearchableField("content") { AnalyzerName = "en.microsoft" },
new SimpleField("category", SearchFieldDataType.String) { IsFacetable = true },
new SimpleField("sourcepage", SearchFieldDataType.String) { IsFacetable = true },
new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true }
new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true },
new SearchField("embedding", SearchFieldDataType.Collection(SearchFieldDataType.Single))
{
VectorSearchDimensions = 1536,
IsSearchable = true,
VectorSearchConfiguration = vectorSearchConfigName,
}
},
SemanticSettings = new SemanticSettings
{
Expand Down Expand Up @@ -502,11 +521,14 @@ Indexing sections from '{fileName}' into search index '{options.SearchIndexName}
}

var searchClient = await GetSearchClientAsync(options);
var openAIClient = await GetAzureOpenAIClientAsync(options);

var iteration = 0;
var batch = new IndexDocumentsBatch<SearchDocument>();
foreach (var section in sections)
{
var embeddings = await openAIClient.GetEmbeddingsAsync(options.EmbeddingModelName, new Azure.AI.OpenAI.EmbeddingsOptions(section.Content.Replace('\r', ' ')));
var embedding = embeddings.Value.Data.FirstOrDefault()?.Embedding.ToArray() ?? new float[0];
batch.Actions.Add(new IndexDocumentsAction<SearchDocument>(
IndexActionType.MergeOrUpload,
new SearchDocument
Expand All @@ -515,7 +537,8 @@ Indexing sections from '{fileName}' into search index '{options.SearchIndexName}
["content"] = section.Content,
["category"] = section.Category,
["sourcepage"] = section.SourcePage,
["sourcefile"] = section.SourceFile
["sourcefile"] = section.SourceFile,
["embedding"] = embedding,
}));

iteration++;
Expand Down
Loading

0 comments on commit d74cef4

Please sign in to comment.