diff --git a/app/Directory.Build.props b/app/Directory.Build.props index 92ed207b..bdc9812b 100644 --- a/app/Directory.Build.props +++ b/app/Directory.Build.props @@ -5,8 +5,8 @@ 1.10.0 11.5.0-beta.2 12.17.0 - 0.13.277.1-preview - 1.0.0-beta.6 + 0.24.230918.1-preview + 1.0.0-beta.7 7.0.10 7.0.10 2.0.1 diff --git a/app/backend/Extensions/ServiceCollectionExtensions.cs b/app/backend/Extensions/ServiceCollectionExtensions.cs index b1e546e5..3a5c18ff 100644 --- a/app/backend/Extensions/ServiceCollectionExtensions.cs +++ b/app/backend/Extensions/ServiceCollectionExtensions.cs @@ -66,8 +66,6 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv }); services.AddSingleton(); - - services.AddSingleton(); services.AddSingleton(); return services; @@ -85,6 +83,4 @@ internal static IServiceCollection AddCrossOriginResourceSharing(this IServiceCo return services; } - - internal static IServiceCollection AddMemoryStore(this IServiceCollection services) => services.AddSingleton(); } diff --git a/app/backend/Services/AzureOpenAIChatCompletionService.cs b/app/backend/Services/AzureOpenAIChatCompletionService.cs deleted file mode 100644 index 77a8c6cf..00000000 --- a/app/backend/Services/AzureOpenAIChatCompletionService.cs +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. 
- -namespace MinimalApi.Services; - -public sealed class AzureOpenAIChatCompletionService : ITextCompletion -{ - private readonly OpenAIClient _openAIClient; - private readonly string _deployedModelName; - - public AzureOpenAIChatCompletionService(OpenAIClient openAIClient, IConfiguration config) - { - _openAIClient = openAIClient; - - var deployedModelName = config["AzureOpenAiChatGptDeployment"]; - ArgumentNullException.ThrowIfNullOrEmpty(deployedModelName); - _deployedModelName = deployedModelName; - } - - - public async Task CompleteAsync( - string text, CompleteRequestSettings requestSettings, CancellationToken cancellationToken = default) - { - var option = new CompletionsOptions - { - MaxTokens = requestSettings.MaxTokens, - FrequencyPenalty = Convert.ToSingle(requestSettings.FrequencyPenalty), - PresencePenalty = Convert.ToSingle(requestSettings.PresencePenalty), - Temperature = Convert.ToSingle(requestSettings.Temperature), - Prompts = { text }, - }; - - foreach (var stopSequence in requestSettings.StopSequences) - { - option.StopSequences.Add(stopSequence); - } - - var response = - await _openAIClient.GetCompletionsAsync( - _deployedModelName, option, cancellationToken); - return response.Value is Completions completions && completions.Choices.Count > 0 - ? completions.Choices[0].Text - : throw new AIException(AIException.ErrorCodes.InvalidConfiguration, "completion not found"); - } -} diff --git a/app/backend/Services/CorpusMemoryStore.cs b/app/backend/Services/CorpusMemoryStore.cs deleted file mode 100644 index fc279bb9..00000000 --- a/app/backend/Services/CorpusMemoryStore.cs +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. 
- -namespace MinimalApi.Services; - -public sealed class CorpusMemoryStore : IMemoryStore -{ - private readonly ILogger _logger; - private readonly BlobServiceClient _blobServiceClient; - private readonly IMemoryStore _store = new VolatileMemoryStore(); - - // TODO: Consider using the StringBuilderObjectPool approach for reusing builders in tight loops. - // https://learn.microsoft.com/aspnet/core/performance/objectpool?view=aspnetcore-7.0 - public CorpusMemoryStore(BlobServiceClient blobServiceClient, ILogger logger) => (_blobServiceClient, _logger) = (blobServiceClient, logger); - - public Task CreateCollectionAsync( - string collectionName, - CancellationToken cancel = default) => _store.CreateCollectionAsync(collectionName, cancel); - - public Task DeleteCollectionAsync( - string collectionName, - CancellationToken cancel = default) => _store.DeleteCollectionAsync(collectionName, cancel); - - public Task DoesCollectionExistAsync( - string collectionName, - CancellationToken cancel = default) => _store.DoesCollectionExistAsync(collectionName, cancel); - - public Task GetAsync( - string collectionName, - string key, - bool withEmbedding = false, - CancellationToken cancel = default) => _store.GetAsync(collectionName, key, withEmbedding, cancel); - - public IAsyncEnumerable GetBatchAsync( - string collectionName, - IEnumerable keys, - bool withEmbeddings = false, - CancellationToken cancel = default) => _store.GetBatchAsync(collectionName, keys, withEmbeddings, cancel); - - public IAsyncEnumerable GetCollectionsAsync(CancellationToken cancel = default) => _store.GetCollectionsAsync(cancel); - - public Task<(MemoryRecord, double)?> GetNearestMatchAsync( - string collectionName, - Embedding embedding, - double minRelevanceScore = 0, - bool withEmbedding = false, - CancellationToken cancel = default) => _store.GetNearestMatchAsync(collectionName, embedding, minRelevanceScore, withEmbedding, cancel); - - public IAsyncEnumerable<(MemoryRecord, double)> 
GetNearestMatchesAsync( - string collectionName, - Embedding embedding, - int limit, - double minRelevanceScore = 0, - bool withEmbeddings = false, - CancellationToken cancel = default) => _store.GetNearestMatchesAsync(collectionName, embedding, limit, minRelevanceScore, withEmbeddings, cancel); - - public async Task InitializeAsync() - { - _logger.LogInformation("Loading corpus ..."); - - var blobContainerClient = _blobServiceClient.GetBlobContainerClient("corpus"); - var corpus = new List(); - await foreach (var blob in blobContainerClient.GetBlobsAsync()) - { - var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(blob.Name); - var source = $"{fileNameWithoutExtension}.pdf"; - using var readStream = blobContainerClient.GetBlobClient(blob.Name).OpenRead(); - using var reader = new StreamReader(readStream); - var content = await reader.ReadToEndAsync(); - - // Split contents into short sentences - var sentences = content.Split(new[] { '.', '?', '!' }, StringSplitOptions.RemoveEmptyEntries); - var corpusIndex = 0; - var sb = new StringBuilder(); - - // Create corpus records based on sentences - foreach (var sentence in sentences) - { - sb.Append(sentence); - if (sb.Length > 256) - { - var id = $"{source}+{corpusIndex++}"; - corpus.Add(new CorpusRecord(id, source, sb.ToString())); - sb.Clear(); - } - } - } - - _logger.LogInformation("Load {Count} records into corpus", corpus.Count); - _logger.LogInformation("Loading corpus into memory..."); - - var embeddingService = new SentenceEmbeddingService(corpus); - var collectionName = "knowledge"; - - await _store.CreateCollectionAsync(collectionName); - - var embeddings = await embeddingService.GenerateEmbeddingsAsync(corpus.Select(c => c.Text).ToList()); - var memoryRecords = - Enumerable.Zip(corpus, embeddings) - .Select((tuple) => - { - var (corpusRecord, embedding) = tuple; - var metaData = new MemoryRecordMetadata(true, corpusRecord.Id, corpusRecord.Text, corpusRecord.Source, string.Empty, string.Empty); - 
var memoryRecord = new MemoryRecord(metaData, embedding, key: corpusRecord.Id); - return memoryRecord; - }); - - _ = await _store.UpsertBatchAsync(collectionName, memoryRecords).ToListAsync(); - } - - public Task RemoveAsync( - string collectionName, - string key, - CancellationToken cancel = default) => _store.RemoveAsync(collectionName, key, cancel); - - public Task RemoveBatchAsync( - string collectionName, - IEnumerable keys, - CancellationToken cancel = default) => _store.RemoveBatchAsync(collectionName, keys, cancel); - - public Task UpsertAsync( - string collectionName, - MemoryRecord record, - CancellationToken cancel = default) => _store.UpsertAsync(collectionName, record, cancel); - - public IAsyncEnumerable UpsertBatchAsync( - string collectionName, - IEnumerable records, - CancellationToken cancel = default) => _store.UpsertBatchAsync(collectionName, records, cancel); -} diff --git a/app/backend/Services/ReadRetrieveReadChatService.cs b/app/backend/Services/ReadRetrieveReadChatService.cs index b1417223..9a750590 100644 --- a/app/backend/Services/ReadRetrieveReadChatService.cs +++ b/app/backend/Services/ReadRetrieveReadChatService.cs @@ -1,61 +1,25 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.AspNetCore.Mvc.ViewFeatures; +using Microsoft.Extensions.Azure; +using Microsoft.SemanticKernel.AI.ChatCompletion; + namespace MinimalApi.Services; public class ReadRetrieveReadChatService { private readonly SearchClient _searchClient; - private readonly AzureOpenAIChatCompletionService _completionService; private readonly IKernel _kernel; private readonly IConfiguration _configuration; - private const string FollowUpQuestionsPrompt = """ - After answering question, also generate three very brief follow-up questions that the user would likely ask next. - Use double angle brackets to reference the questions, e.g. <>. - Try not to repeat questions that have already been asked. 
- Only generate questions and do not generate any text before or after the questions, such as 'Next Questions' - """; - - private const string AnswerPromptTemplate = """ - <|im_start|>system - You are a system assistant who helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers. - Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. - {{$follow_up_questions_prompt}} - For tabular information return it as an html table. Do not return markdown format. - Each source has a name followed by colon and the actual information, ALWAYS reference source for each fact you use in the response. Use square brakets to reference the source. List each source separately. - {{$injected_prompt}} - - Here're a few examples: - ### Good Example 1 (include source) ### - Apple is a fruit[reference1.pdf]. - ### Good Example 2 (include multiple source) ### - Apple is a fruit[reference1.pdf][reference2.pdf]. - ### Good Example 2 (include source and use double angle brackets to reference question) ### - Microsoft is a software company[reference1.pdf]. <> <> <> - ### END ### - Sources: - {{$sources}} - - Chat history: - {{$chat_history}} - <|im_end|> - <|im_start|>user - {{$question}} - <|im_end|> - <|im_start|>assistant - """; - public ReadRetrieveReadChatService( SearchClient searchClient, - AzureOpenAIChatCompletionService completionService, + OpenAIClient client, IConfiguration configuration) { _searchClient = searchClient; - _completionService = completionService; - var deployedModelName = configuration["AzureOpenAiChatGptDeployment"]; - var kernel = Kernel.Builder.Build(); - kernel.Config.AddTextCompletionService(deployedModelName!, _ => completionService); - _kernel = kernel; + var deployedModelName = configuration["AzureOpenAiChatGptDeployment"] ?? 
throw new ArgumentNullException(); + _kernel = Kernel.Builder.WithAzureChatCompletionService(deployedModelName, client).Build(); _configuration = configuration; } @@ -69,92 +33,122 @@ public async Task ReplyAsync( var useSemanticRanker = overrides?.SemanticRanker ?? false; var excludeCategory = overrides?.ExcludeCategory ?? null; var filter = excludeCategory is null ? null : $"category ne '{excludeCategory}'"; + IChatCompletion chat = _kernel.GetService(); // step 1 // use llm to get query - var queryFunction = CreateQueryPromptFunction(history); - var context = new ContextVariables(); - var historyText = history.GetChatHistoryAsText(includeLastTurn: true); - context["chat_history"] = historyText; - context["question"] = history.LastOrDefault()?.User is { } userQuestion + var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question. +Make your response simple and precise. Return the query only, do not return any other text. +e.g. +Northwind Health Plus AND standard plan. +standard plan AND dental AND employee benefit. +"); + var question = history.LastOrDefault()?.User is { } userQuestion ? 
userQuestion : throw new InvalidOperationException("User question is null"); + getQueryChat.AddUserMessage(question); + var result = await chat.GetChatCompletionsAsync( + getQueryChat, + new ChatRequestSettings + { + Temperature = 0, + MaxTokens = 128, + }, + cancellationToken); + + if (result.Count != 1) + { + throw new InvalidOperationException("Failed to get search query"); + } - var query = await _kernel.RunAsync(context, cancellationToken, queryFunction); + var query = result[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content; // step 2 // use query to search related docs - var documentContents = await _searchClient.QueryDocumentsAsync(query.Result, overrides, cancellationToken); + var documentContents = await _searchClient.QueryDocumentsAsync(query, overrides, cancellationToken); - // step 3 - // use llm to get answer - var answerContext = new ContextVariables(); - ISKFunction answerFunction; - string prompt; - answerContext["chat_history"] = history.GetChatHistoryAsText(); - answerContext["sources"] = documentContents; - answerContext["follow_up_questions_prompt"] = overrides?.SuggestFollowupQuestions is true ? 
ReadRetrieveReadChatService.FollowUpQuestionsPrompt : string.Empty; - - if (overrides is null or { PromptTemplate: null }) - { - answerContext["$injected_prompt"] = string.Empty; - answerFunction = CreateAnswerPromptFunction(ReadRetrieveReadChatService.AnswerPromptTemplate, overrides); - prompt = ReadRetrieveReadChatService.AnswerPromptTemplate; - } - else if (overrides is not null && overrides.PromptTemplate.StartsWith(">>>")) + if (string.IsNullOrEmpty(documentContents)) { - answerContext["$injected_prompt"] = overrides.PromptTemplate[3..]; - answerFunction = CreateAnswerPromptFunction(ReadRetrieveReadChatService.AnswerPromptTemplate, overrides); - prompt = ReadRetrieveReadChatService.AnswerPromptTemplate; + documentContents = "no source available"; } - else if (overrides?.PromptTemplate is string promptTemplate) + + // step 3 + // put together related docs and conversation history to generate answer + var answerChat = chat.CreateNewChat($@"You are a system assistant who helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers"); + + // add chat history + foreach (var turn in history) { - answerFunction = CreateAnswerPromptFunction(promptTemplate, overrides); - prompt = promptTemplate; + answerChat.AddUserMessage(turn.User); + if (turn.Bot is { } botMessage) + { + answerChat.AddAssistantMessage(botMessage); + } } - else + + // format prompt + answerChat.AddUserMessage(@$" ## Source ## +{documentContents} +## End ## + +Your answer needs to be a json object with the following format. +{{ + ""answer"": // the answer to the question, remember to reference the source for each fact you use in the response. e.g. Apple is a fruit[reference1.pdf]. If no source is provided, put the answer as I don't know. + ""thoughts"": // brief thoughts on how you came up with the answer, e.g. what sources you used, what you thought about, etc. 
+}}"); + + // get answer + var answer = await chat.GetChatCompletionsAsync( + answerChat, + new ChatRequestSettings + { + Temperature = overrides?.Temperature ?? 0.7, + MaxTokens = 1024, + }, + cancellationToken); + var answerJson = answer[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content; + var answerObject = JsonSerializer.Deserialize(answerJson); + var ans = answerObject.GetProperty("answer").GetString() ?? throw new InvalidOperationException("Failed to get answer"); + var thoughts = answerObject.GetProperty("thoughts").GetString() ?? throw new InvalidOperationException("Failed to get thoughts"); + + // step 4 + // add follow up questions if requested + if (overrides?.SuggestFollowupQuestions is true) { - throw new InvalidOperationException("Failed to get search result"); + var followUpQuestionChat = chat.CreateNewChat(@"You are a helpful AI assistant"); + followUpQuestionChat.AddUserMessage($@"Generate three follow-up questions based on the answer you just generated. +# Answer +{ans} + +# Format of the response +Return the follow-up questions as a json string list. +e.g. 
+[ + ""What is the deductible?"", + ""What is the co-pay?"", + ""What is the out-of-pocket maximum?"" +]"); + + var followUpQuestions = await chat.GetChatCompletionsAsync( + followUpQuestionChat, + new ChatRequestSettings + { + Temperature = 0, + MaxTokens = 256, + }, + cancellationToken); + + var followUpQuestionsJson = followUpQuestions[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content; + var followUpQuestionsObject = JsonSerializer.Deserialize(followUpQuestionsJson); + var followUpQuestionsList = followUpQuestionsObject.EnumerateArray().Select(x => x.GetString()).ToList(); + foreach (var followUpQuestion in followUpQuestionsList) + { + ans += $" <<{followUpQuestion}>> "; + } } - - var ans = await _kernel.RunAsync(answerContext, cancellationToken, answerFunction); - prompt = await _kernel.PromptTemplateEngine.RenderAsync(prompt, ans); return new ApproachResponse( DataPoints: documentContents.Split('\r'), - Answer: ans.Result, - Thoughts: $"Searched for:
{query}

Prompt:
{prompt.Replace("\n", "
")}", + Answer: ans, + Thoughts: thoughts, CitationBaseUrl: _configuration.ToCitationBaseUrl()); } - - private ISKFunction CreateQueryPromptFunction(ChatTurn[] history) - { - var queryPromptTemplate = """ - <|im_start|>system - Chat history: - {{$chat_history}} - - Here's a few examples of good search queries: - ### Good example 1 ### - Northwind Health Plus AND standard plan - ### Good example 2 ### - standard plan AND dental AND employee benefit - ### - - <|im_end|> - <|im_start|>system - Generate search query for followup question. You can refer to chat history for context information. Just return search query and don't include any other information. - {{$question}} - <|im_end|> - <|im_start|>assistant - """; - - return _kernel.CreateSemanticFunction(queryPromptTemplate, - temperature: 0, - maxTokens: 32, - stopSequences: new[] { "<|im_end|>" }); - } - - private ISKFunction CreateAnswerPromptFunction(string answerTemplate, RequestOverrides? overrides) => _kernel.CreateSemanticFunction(answerTemplate, - temperature: overrides?.Temperature ?? 0.7, - maxTokens: 1024, - stopSequences: new[] { "<|im_end|>" }); } diff --git a/app/backend/Services/SentenceEmbeddingService.cs b/app/backend/Services/SentenceEmbeddingService.cs deleted file mode 100644 index e36d3609..00000000 --- a/app/backend/Services/SentenceEmbeddingService.cs +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -namespace MinimalApi.Services; - -internal sealed class SentenceEmbeddingService : IEmbeddingGeneration -{ - private readonly MLContext _mlContext; - private PredictionEngine? 
_predictionEngine; - - public SentenceEmbeddingService(IEnumerable corpusToTrain) - { - _mlContext = new MLContext(0); - Train(corpusToTrain.Select(c => c.Text)); - } - - private void Train(IEnumerable inputs) - { - var featurizeTextOption = new TextFeaturizingEstimator.Options - { - StopWordsRemoverOptions = new StopWordsRemovingEstimator.Options - { - Language = TextFeaturizingEstimator.Language.English, - } - }; - var textFeaturizer = _mlContext.Transforms.Text.FeaturizeText(outputColumnName: "Embedding", featurizeTextOption, "Text"); - var model = textFeaturizer.Fit(_mlContext.Data.LoadFromEnumerable(inputs.Select(i => new { Text = i }))); - _predictionEngine = _mlContext.Model.CreatePredictionEngine(model); - } - - public Task>> GenerateEmbeddingsAsync(IList data, CancellationToken cancellationToken = default) - { - var outputs = data.Select(i => _predictionEngine!.Predict(new Input { Text = i })).ToList(); - var embeddings = outputs.Select(o => new Embedding(o.Embedding!)).ToList(); - return Task.FromResult>>(embeddings); - } - - private class Input - { - public string Text { get; set; } = string.Empty; - } - - private class Output - { - public float[] Embedding { get; set; } = Array.Empty(); - } -}