diff --git a/app/Directory.Build.props b/app/Directory.Build.props index bdc9812b..5009285c 100644 --- a/app/Directory.Build.props +++ b/app/Directory.Build.props @@ -3,7 +3,7 @@ 4.1.0 1.10.0 - 11.5.0-beta.2 + 11.5.0-beta.4 12.17.0 0.24.230918.1-preview 1.0.0-beta.7 diff --git a/app/backend/Extensions/SearchClientExtensions.cs b/app/backend/Extensions/SearchClientExtensions.cs index 14a4c05c..d1bed905 100644 --- a/app/backend/Extensions/SearchClientExtensions.cs +++ b/app/backend/Extensions/SearchClientExtensions.cs @@ -6,7 +6,8 @@ internal static class SearchClientExtensions { internal static async Task QueryDocumentsAsync( this SearchClient searchClient, - string query, + string? query = null, + float[]? embedding = null, RequestOverrides? overrides = null, CancellationToken cancellationToken = default) { @@ -34,6 +35,20 @@ internal static async Task QueryDocumentsAsync( Size = top, }; + if (embedding != null && overrides?.RetrievalMode != "Text") + { + var k = useSemanticRanker ? 50 : top; + var vectorQuery = new SearchQueryVector + { + // if semantic ranker is enabled, we need to set the rank to a large number to get more + // candidates for semantic reranking + KNearestNeighborsCount = useSemanticRanker ? 50 : top, + Value = embedding, + }; + vectorQuery.Fields.Add("embedding"); + searchOption.Vectors.Add(vectorQuery); + } + var searchResultResponse = await searchClient.SearchAsync(query, searchOption, cancellationToken); if (searchResultResponse.Value is null) { @@ -85,51 +100,4 @@ internal static async Task QueryDocumentsAsync( return documentContents; } - - internal static async Task LookupAsync( - this SearchClient searchClient, - string query, - RequestOverrides? overrides = null) - { - var option = new SearchOptions - { - Size = 1, - IncludeTotalCount = true, - QueryType = SearchQueryType.Semantic, - QueryLanguage = "en-us", - QuerySpeller = "lexicon", - SemanticConfigurationName = "default", - QueryAnswer = "extractive", - QueryCaption = "extractive", - }; - - var searchResultResponse = await searchClient.SearchAsync(query, option); - if (searchResultResponse.Value is null) - { - throw new InvalidOperationException("fail to get search result"); - } - - var searchResult = searchResultResponse.Value; - if (searchResult is { Answers.Count: > 0 }) - { - return searchResult.Answers[0].Text; - } - - if (searchResult.TotalCount > 0) - { - var contents = new List(); - await foreach (var doc in searchResult.GetResultsAsync()) - { - doc.Document.TryGetValue("content", out var contentValue); - if (contentValue is string content) - { - contents.Add(content); - } - } - - return string.Join("\n", contents); - } - - return string.Empty; - } } diff --git a/app/backend/Extensions/ServiceCollectionExtensions.cs b/app/backend/Extensions/ServiceCollectionExtensions.cs index 3a5c18ff..eed5b457 100644 --- a/app/backend/Extensions/ServiceCollectionExtensions.cs +++ b/app/backend/Extensions/ServiceCollectionExtensions.cs @@ -44,8 +44,7 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv services.AddSingleton(sp => { var config = sp.GetRequiredService(); - var azureOpenAiServiceEndpoint = config["AzureOpenAiServiceEndpoint"]; - ArgumentNullException.ThrowIfNullOrEmpty(azureOpenAiServiceEndpoint); + var azureOpenAiServiceEndpoint = config["AzureOpenAiServiceEndpoint"] ?? throw new ArgumentNullException(); var documentAnalysisClient = new DocumentAnalysisClient( new Uri(azureOpenAiServiceEndpoint), s_azureCredential); diff --git a/app/backend/Services/ReadRetrieveReadChatService.cs b/app/backend/Services/ReadRetrieveReadChatService.cs index 9a750590..6a4b544e 100644 --- a/app/backend/Services/ReadRetrieveReadChatService.cs +++ b/app/backend/Services/ReadRetrieveReadChatService.cs @@ -19,7 +19,14 @@ public ReadRetrieveReadChatService( { _searchClient = searchClient; var deployedModelName = configuration["AzureOpenAiChatGptDeployment"] ?? throw new ArgumentNullException(); - _kernel = Kernel.Builder.WithAzureChatCompletionService(deployedModelName, client).Build(); + var kernelBuilder = Kernel.Builder.WithAzureChatCompletionService(deployedModelName, client); + var embeddingModelName = configuration["AzureOpenAiEmbeddingDeployment"]; + if (!string.IsNullOrEmpty(embeddingModelName)) + { + var endpoint = configuration["AzureOpenAiServiceEndpoint"] ?? throw new ArgumentNullException(); + kernelBuilder = kernelBuilder.WithAzureTextEmbeddingGenerationService(embeddingModelName, endpoint, new DefaultAzureCredential()); + } + _kernel = kernelBuilder.Build(); _configuration = configuration; } @@ -34,37 +41,49 @@ public async Task ReplyAsync( var excludeCategory = overrides?.ExcludeCategory ?? null; var filter = excludeCategory is null ? null : $"category ne '{excludeCategory}'"; IChatCompletion chat = _kernel.GetService(); + ITextEmbeddingGeneration? embedding = _kernel.GetService(); + float[]? embeddings = null; + var question = history.LastOrDefault()?.User is { } userQuestion + ? userQuestion + : throw new InvalidOperationException("Use question is null"); + if (overrides?.RetrievalMode != "Text" && embedding is not null) + { + embeddings = (await embedding.GenerateEmbeddingAsync(question)).ToArray(); + } // step 1 - // use llm to get query - var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question. + // use llm to get query if retrieval mode is not vector + string? query = null; + if (overrides?.RetrievalMode != "Vector") + { + var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question. Make your respond simple and precise. Return the query only, do not return any other text. e.g. Northwind Health Plus AND standard plan. standard plan AND dental AND employee benefit. "); - var question = history.LastOrDefault()?.User is { } userQuestion - ? userQuestion - : throw new InvalidOperationException("Use question is null"); - getQueryChat.AddUserMessage(question); - var result = await chat.GetChatCompletionsAsync( - getQueryChat, - new ChatRequestSettings + + getQueryChat.AddUserMessage(question); + var result = await chat.GetChatCompletionsAsync( + getQueryChat, + new ChatRequestSettings + { + Temperature = 0, + MaxTokens = 128, + }, + cancellationToken); + + if (result.Count != 1) { - Temperature = 0, - MaxTokens = 128, - }, - cancellationToken); + throw new InvalidOperationException("Failed to get search query"); + } - if (result.Count != 1) - { - throw new InvalidOperationException("Failed to get search query"); + query = result[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content; } - - var query = result[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content; + // step 2 // use query to search related docs - var documentContents = await _searchClient.QueryDocumentsAsync(query, overrides, cancellationToken); + var documentContents = await _searchClient.QueryDocumentsAsync(query, embeddings, overrides, cancellationToken); if (string.IsNullOrEmpty(documentContents)) { diff --git a/app/frontend/Components/SettingsPanel.razor b/app/frontend/Components/SettingsPanel.razor index 5d68327f..0145a820 100644 --- a/app/frontend/Components/SettingsPanel.razor +++ b/app/frontend/Components/SettingsPanel.razor @@ -8,22 +8,6 @@
- @if (_supportedSettings is not SupportedSettings.Chat) - { - Approach - - - Retrieve-Then-Read - - - Read-Retrieve-Read - - - Read-Decompose-Ask - - - } - @@ -41,6 +25,18 @@ + Retrieval Mode + + + Text + + + Hybrid + + + Vector + + diff --git a/app/prepdocs/PrepareDocs/AppOptions.cs b/app/prepdocs/PrepareDocs/AppOptions.cs index ecb72e28..b3e5fe75 100644 --- a/app/prepdocs/PrepareDocs/AppOptions.cs +++ b/app/prepdocs/PrepareDocs/AppOptions.cs @@ -10,7 +10,9 @@ internal record class AppOptions( string? Container, string? TenantId, string? SearchServiceEndpoint, + string? AzureOpenAIServiceEndpoint, string? SearchIndexName, + string? EmbeddingModelName, bool Remove, bool RemoveAll, string? FormRecognizerServiceEndpoint, diff --git a/app/prepdocs/PrepareDocs/GlobalUsings.cs b/app/prepdocs/PrepareDocs/GlobalUsings.cs index c521aeb3..d1ec4967 100644 --- a/app/prepdocs/PrepareDocs/GlobalUsings.cs +++ b/app/prepdocs/PrepareDocs/GlobalUsings.cs @@ -21,3 +21,4 @@ global using PdfSharpCore.Pdf; global using PdfSharpCore.Pdf.IO; global using PrepareDocs; +global using Azure.AI.OpenAI; diff --git a/app/prepdocs/PrepareDocs/PrepareDocs.csproj b/app/prepdocs/PrepareDocs/PrepareDocs.csproj index 37896762..b5188169 100644 --- a/app/prepdocs/PrepareDocs/PrepareDocs.csproj +++ b/app/prepdocs/PrepareDocs/PrepareDocs.csproj @@ -14,6 +14,7 @@ + diff --git a/app/prepdocs/PrepareDocs/Program.Clients.cs b/app/prepdocs/PrepareDocs/Program.Clients.cs index a54571b0..3e6b8616 100644 --- a/app/prepdocs/PrepareDocs/Program.Clients.cs +++ b/app/prepdocs/PrepareDocs/Program.Clients.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. + internal static partial class Program { private static BlobContainerClient? s_corpusContainerClient; @@ -7,12 +8,14 @@ internal static partial class Program private static DocumentAnalysisClient? s_documentClient; private static SearchIndexClient? s_searchIndexClient; private static SearchClient? s_searchClient; + private static OpenAIClient? s_openAIClient; private static readonly SemaphoreSlim s_corpusContainerLock = new(1); private static readonly SemaphoreSlim s_containerLock = new(1); private static readonly SemaphoreSlim s_documentLock = new(1); private static readonly SemaphoreSlim s_searchIndexLock = new(1); private static readonly SemaphoreSlim s_searchLock = new(1); + private static readonly SemaphoreSlim s_openAILock = new(1); private static Task GetCorpusBlobContainerClientAsync(AppOptions options) => GetLazyClientAsync(options, s_corpusContainerLock, static async o => @@ -119,6 +122,21 @@ private static Task GetSearchClientAsync(AppOptions options) => return s_searchClient; }); + private static Task GetAzureOpenAIClientAsync(AppOptions options) => + GetLazyClientAsync(options, s_openAILock, async o => + { + if (s_openAIClient is null) + { + var endpoint = o.AzureOpenAIServiceEndpoint; + ArgumentNullException.ThrowIfNullOrEmpty(endpoint); + s_openAIClient = new OpenAIClient( + new Uri(endpoint), + DefaultCredential); + } + await Task.CompletedTask; + return s_openAIClient; + }); + private static async Task GetLazyClientAsync( AppOptions options, SemaphoreSlim locker, diff --git a/app/prepdocs/PrepareDocs/Program.Options.cs b/app/prepdocs/PrepareDocs/Program.Options.cs index 7343a545..5de69f18 100644 --- a/app/prepdocs/PrepareDocs/Program.Options.cs +++ b/app/prepdocs/PrepareDocs/Program.Options.cs @@ -26,6 +26,12 @@ internal static partial class Program private static readonly Option s_searchIndexName = new(name: "--searchindex", description: "Name of the Azure Cognitive Search index where content should be indexed (will be created if it doesn't exist)"); + private static readonly Option s_azureOpenAIService = + new(name: "--openaiendpoint", description: "Optional. The Azure OpenAI service endpoint which will be used to extract text, tables and layout from the documents (must exist already)"); + + private static readonly Option s_embeddingModelName = + new(name: "--embeddingmodel", description: "Optional. Name of the Azure Cognitive Search embedding model to use for embedding content in the search index (will be created if it doesn't exist)"); + private static readonly Option s_remove = new(name: "--remove", description: "Remove references to this document from blob storage and the search index"); @@ -45,7 +51,7 @@ internal static partial class Program """) { s_files, s_category, s_skipBlobs, s_storageEndpoint, - s_container, s_tenantId, s_searchService, s_searchIndexName, + s_container, s_tenantId, s_searchService, s_searchIndexName, s_azureOpenAIService, s_embeddingModelName, s_remove, s_removeAll, s_formRecognizerServiceEndpoint, s_verbose }; @@ -58,6 +64,8 @@ internal static partial class Program TenantId: context.ParseResult.GetValueForOption(s_tenantId), SearchServiceEndpoint: context.ParseResult.GetValueForOption(s_searchService), SearchIndexName: context.ParseResult.GetValueForOption(s_searchIndexName), + AzureOpenAIServiceEndpoint: context.ParseResult.GetValueForOption(s_azureOpenAIService), + EmbeddingModelName: context.ParseResult.GetValueForOption(s_embeddingModelName), Remove: context.ParseResult.GetValueForOption(s_remove), RemoveAll: context.ParseResult.GetValueForOption(s_removeAll), FormRecognizerServiceEndpoint: context.ParseResult.GetValueForOption(s_formRecognizerServiceEndpoint), diff --git a/app/prepdocs/PrepareDocs/Program.cs b/app/prepdocs/PrepareDocs/Program.cs index 047f097b..896141f4 100644 --- a/app/prepdocs/PrepareDocs/Program.cs +++ b/app/prepdocs/PrepareDocs/Program.cs @@ -144,6 +144,10 @@ Removing sections from '{fileName ?? "all"}' from search index '{options.SearchI }); } + if (documentsToDelete.Count == 0) + { + break; + } Response deleteResponse = await searchClient.DeleteDocumentsAsync(documentsToDelete); @@ -176,15 +180,30 @@ static async ValueTask CreateSearchIndexAsync(AppOptions options) } } + string vectorSearchConfigName = "my-vector-config"; + var index = new SearchIndex(options.SearchIndexName) { + VectorSearch = new() + { + AlgorithmConfigurations = + { + new HnswVectorSearchAlgorithmConfiguration(vectorSearchConfigName) + } + }, Fields = { new SimpleField("id", SearchFieldDataType.String) { IsKey = true }, new SearchableField("content") { AnalyzerName = "en.microsoft" }, new SimpleField("category", SearchFieldDataType.String) { IsFacetable = true }, new SimpleField("sourcepage", SearchFieldDataType.String) { IsFacetable = true }, - new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true } + new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true }, + new SearchField("embedding", SearchFieldDataType.Collection(SearchFieldDataType.Single)) + { + VectorSearchDimensions = 1536, + IsSearchable = true, + VectorSearchConfiguration = vectorSearchConfigName, + } }, SemanticSettings = new SemanticSettings { @@ -502,11 +521,14 @@ Indexing sections from '{fileName}' into search index '{options.SearchIndexName} } var searchClient = await GetSearchClientAsync(options); + var openAIClient = await GetAzureOpenAIClientAsync(options); var iteration = 0; var batch = new IndexDocumentsBatch(); foreach (var section in sections) { + var embeddings = await openAIClient.GetEmbeddingsAsync(options.EmbeddingModelName, new Azure.AI.OpenAI.EmbeddingsOptions(section.Content.Replace('\r', ' '))); + var embedding = embeddings.Value.Data.FirstOrDefault()?.Embedding.ToArray() ?? new float[0]; batch.Actions.Add(new IndexDocumentsAction( IndexActionType.MergeOrUpload, new SearchDocument @@ -515,7 +537,8 @@ Indexing sections from '{fileName}' into search index '{options.SearchIndexName} ["content"] = section.Content, ["category"] = section.Category, ["sourcepage"] = section.SourcePage, - ["sourcefile"] = section.SourceFile + ["sourcefile"] = section.SourceFile, + ["embedding"] = embedding, })); iteration++; diff --git a/app/prepdocs/PrepareDocs/Properties/launchSettings.json b/app/prepdocs/PrepareDocs/Properties/launchSettings.json index fe16cee5..48ee9474 100644 --- a/app/prepdocs/PrepareDocs/Properties/launchSettings.json +++ b/app/prepdocs/PrepareDocs/Properties/launchSettings.json @@ -2,7 +2,7 @@ "profiles": { "PrepareDocs": { "commandName": "Project", - "commandLineArgs": "../../../../../../data/*.pdf\r\n--storageendpoint %AZURE_STORAGE_BLOB_ENDPOINT%\r\n--container %AZURE_STORAGE_CONTAINER%\r\n--searchendpoint %AZURE_SEARCH_SERVICE_ENDPOINT%\r\n--searchindex %AZURE_SEARCH_INDEX%\r\n--formrecognizerendpoint %AZURE_FORMRECOGNIZER_SERVICE_ENDPOINT%\r\n--tenantid %AZURE_TENANT_ID%\r\n-v" + "commandLineArgs": "../../../../../../data/*.pdf\r\n--storageendpoint %AZURE_STORAGE_BLOB_ENDPOINT%\r\n--container %AZURE_STORAGE_CONTAINER%\r\n--searchendpoint %AZURE_SEARCH_SERVICE_ENDPOINT%\r\n--searchindex %AZURE_SEARCH_INDEX%\r\n--formrecognizerendpoint %AZURE_FORMRECOGNIZER_SERVICE_ENDPOINT%\r\n--tenantid %AZURE_TENANT_ID%\r\n--openaiendpoint %AZURE_OPENAI_ENDPOINT%\r\n--embeddingmodel %AZURE_OPENAI_EMBEDDING_DEPLOYMENT%\r\n-v" } } } \ No newline at end of file diff --git a/app/shared/Shared/Models/RequestOverrides.cs b/app/shared/Shared/Models/RequestOverrides.cs index 059bb2dd..b811283e 100644 --- a/app/shared/Shared/Models/RequestOverrides.cs +++ b/app/shared/Shared/Models/RequestOverrides.cs @@ -5,6 +5,9 @@ namespace Shared.Models; public record RequestOverrides { public bool SemanticRanker { get; set; } = false; + + public string RetrievalMode { get; set; } = "Vector"; // available option: Text, Vector, Hybrid + public bool? SemanticCaptions { get; set; } public string? ExcludeCategory { get; set; } public int? Top { get; set; } = 3; diff --git a/infra/app/web.bicep b/infra/app/web.bicep index 22e5829d..fba3c664 100644 --- a/infra/app/web.bicep +++ b/infra/app/web.bicep @@ -47,12 +47,12 @@ param formRecognizerEndpoint string @description('The OpenAI endpoint') param openAiEndpoint string -@description('The OpenAI GPT deployment name') -param openAiGptDeployment string - @description('The OpenAI ChatGPT deployment name') param openAiChatGptDeployment string +@description('The OpenAI Embedding deployment name') +param openAiEmbeddingDeployment string + @description('An array of service binds') param serviceBinds array @@ -120,14 +120,14 @@ module app '../core/host/container-app-upsert.bicep' = { name: 'AZURE_OPENAI_ENDPOINT' value: openAiEndpoint } - { - name: 'AZURE_OPENAI_GPT_DEPLOYMENT' - value: openAiGptDeployment - } { name: 'AZURE_OPENAI_CHATGPT_DEPLOYMENT' value: openAiChatGptDeployment } + { + name: 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT' + value: openAiEmbeddingDeployment + } ] targetPort: 80 } diff --git a/infra/main.bicep b/infra/main.bicep index d0530ebf..3a9e1ed8 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -24,6 +24,15 @@ param chatGptDeploymentName string = 'chat' @description('Name of the chat GPT model. Default: gpt-35-turbo') param chatGptModelName string = 'gpt-35-turbo' +@description('Name of the embedding deployment. Default: embedding') +param embeddingDeploymentName string = 'embedding' + +@description('Capacity of the embedding deployment. Default: 30') +param embeddingDeploymentCapacity int = 30 + +@description('Name of the embedding model. Default: text-embedding-ada-002') +param embeddingModelName string = 'text-embedding-ada-002' + @description('Name of the container apps environment') param containerAppsEnvironmentName string = '' @@ -177,6 +186,10 @@ module keyVaultSecrets 'core/security/keyvault-secrets.bicep' = { name: 'AzureOpenAiChatGptDeployment' value: chatGptDeploymentName } + { + name: 'AzureOpenAiEmbeddingDeployment' + value: embeddingDeploymentName + } { name: 'AzureSearchServiceEndpoint' value: searchService.outputs.endpoint @@ -234,6 +247,7 @@ module web './app/web.bicep' = { formRecognizerEndpoint: formRecognizer.outputs.endpoint openAiEndpoint: openAi.outputs.endpoint openAiChatGptDeployment: chatGptDeploymentName + openAiEmbeddingDeployment: embeddingDeploymentName serviceBinds: [ redis.outputs.serviceBind ] } } @@ -289,6 +303,18 @@ module openAi 'core/ai/cognitiveservices.bicep' = { capacity: chatGptDeploymentCapacity } } + { + name: embeddingDeploymentName + model: { + format: 'OpenAI' + name: embeddingModelName + version: '2' + } + sku: { + name: 'Standard' + capacity: embeddingDeploymentCapacity + } + } ] } } @@ -466,6 +492,7 @@ output AZURE_KEY_VAULT_NAME string = keyVault.outputs.name output AZURE_KEY_VAULT_RESOURCE_GROUP string = keyVaultResourceGroup.name output AZURE_LOCATION string = location output AZURE_OPENAI_CHATGPT_DEPLOYMENT string = chatGptDeploymentName +output AZURE_OPENAI_EMBEDDING_DEPLOYMENT string = embeddingDeploymentName output AZURE_OPENAI_ENDPOINT string = openAi.outputs.endpoint output AZURE_OPENAI_RESOURCE_GROUP string = openAiResourceGroup.name output AZURE_OPENAI_SERVICE string = openAi.outputs.name diff --git a/scripts/prepdocs.ps1 b/scripts/prepdocs.ps1 index 3de66432..81db5c86 100644 --- a/scripts/prepdocs.ps1 +++ b/scripts/prepdocs.ps1 @@ -23,6 +23,8 @@ if ([string]::IsNullOrEmpty($env:AZD_PREPDOCS_RAN) -or $env:AZD_PREPDOCS_RAN -eq --container $env:AZURE_STORAGE_CONTAINER ` --searchendpoint $env:AZURE_SEARCH_SERVICE_ENDPOINT ` --searchindex $env:AZURE_SEARCH_INDEX ` + --openaiendpoint $env:AZURE_OPENAI_ENDPOINT ` + --embeddingmodel $env:AZURE_OPENAI_EMBEDDING_DEPLOYMENT ` --formrecognizerendpoint $env:AZURE_FORMRECOGNIZER_SERVICE_ENDPOINT ` --tenantid $env:AZURE_TENANT_ID ` -v