Skip to content

Commit

Permalink
Use AzureEmbedFunctionService from EmbedFunction in PrepDoc (#211)
Browse files Browse the repository at this point in the history
## Purpose
<!-- Describe the intention of the changes being proposed. What problem
does it solve or functionality does it add? -->
* ...

## Does this introduce a breaking change?
<!-- Mark one with an "x". -->
```
[ ] Yes
[ ] No
```

## Pull Request Type
What kind of change does this Pull Request introduce?

<!-- Please check the one that applies to this PR using "x". -->
```
[ ] Bugfix
[ ] Feature
[ ] Code style update (formatting, local variables)
[ ] Refactoring (no functional changes, no api changes)
[ ] Documentation content changes
[ ] Other... Please describe:
```

## How to Test
*  Get the code

```
git clone [repo-address]
cd [repo-name]
git checkout [branch-name]
npm install
```

* Test the code
<!-- Add steps to run the tests suite and/or manually test -->
```
```

## What to Check
Verify that the following are valid
* ...

## Other Information
<!-- Add any other helpful information that may be needed here. -->
  • Loading branch information
LittleLittleCloud authored Nov 1, 2023
1 parent a47ced9 commit eb65d06
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 480 deletions.
1 change: 1 addition & 0 deletions app/functions/EmbedFunctions/EmbedFunctions.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
<PackageReference Include="Azure.AI.FormRecognizer" />
<PackageReference Include="Azure.Search.Documents" />
<PackageReference Include="Azure.Storage.Blobs" />
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="Azure.Storage.Files.Shares" />
<PackageReference Include="Azure.Storage.Queues" />
<PackageReference Include="Microsoft.Azure.Functions.Worker.Extensions.Storage" />
Expand Down
5 changes: 2 additions & 3 deletions app/functions/EmbedFunctions/EmbeddingFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ public sealed class EmbeddingFunction(
public Task EmbedAsync(
[BlobTrigger(
blobPath: "content/{name}",
Connection = "AzureStorageAccountEndpoint")] Stream blobStream,
string name,
BlobClient client) => embeddingAggregateService.EmbedBlobAsync(client, blobStream, blobName: name);
Connection = "AzureWebJobsStorage")] Stream blobStream,
string name) => embeddingAggregateService.EmbedBlobAsync(blobStream, blobName: name);
}
25 changes: 20 additions & 5 deletions app/functions/EmbedFunctions/Program.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using Azure.AI.OpenAI;
using Microsoft.Extensions.DependencyInjection;

var host = new HostBuilder()
.ConfigureServices(services =>
{
Expand Down Expand Up @@ -36,7 +39,7 @@ uri is not null
services.AddSingleton<BlobContainerClient>(_ =>
{
var blobServiceClient = new BlobServiceClient(
GetUriFromEnvironment("AZURE_STORAGE_ACCOUNT_ENDPOINT"),
GetUriFromEnvironment("AZURE_STORAGE_BLOB_ENDPOINT"),
credential);

return blobServiceClient.GetBlobContainerClient("corpus");
Expand All @@ -45,10 +48,22 @@ uri is not null
services.AddSingleton<EmbedServiceFactory>();
services.AddSingleton<EmbeddingAggregateService>();

services.AddSingleton<IEmbedService, AzureSearchEmbedService>();
services.AddSingleton<IEmbedService, PineconeEmbedService>();
services.AddSingleton<IEmbedService, QdrantEmbedService>();
services.AddSingleton<IEmbedService, MilvusEmbedService>();
services.AddSingleton<IEmbedService, AzureSearchEmbedService>(provider =>
{
var searchIndexName = Environment.GetEnvironmentVariable("AZURE_SEARCH_INDEX") ?? throw new ArgumentNullException("AZURE_SEARCH_INDEX is null");
var embeddingModelName = Environment.GetEnvironmentVariable("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") ?? throw new ArgumentNullException("AZURE_OPENAI_EMBEDDING_DEPLOYMENT is null");
var openaiEndPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentNullException("AZURE_OPENAI_ENDPOINT is null");

var openAIClient = new OpenAIClient(new Uri(openaiEndPoint), new DefaultAzureCredential());

var searchClient = provider.GetRequiredService<SearchClient>();
var searchIndexClient = provider.GetRequiredService<SearchIndexClient>();
var blobContainerClient = provider.GetRequiredService<BlobContainerClient>();
var documentClient = provider.GetRequiredService<DocumentAnalysisClient>();
var logger = provider.GetRequiredService<ILogger<AzureSearchEmbedService>>();

return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, logger);
});
})
.ConfigureFunctionsWorkerDefaults()
.Build();
Expand Down
126 changes: 80 additions & 46 deletions app/functions/EmbedFunctions/Services/AzureSearchEmbedService.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
// Copyright (c) Microsoft. All rights reserved.

using Azure.AI.OpenAI;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Options;

namespace EmbedFunctions.Services;

internal sealed partial class AzureSearchEmbedService(
public sealed partial class AzureSearchEmbedService(
OpenAIClient openAIClient,
string embeddingModelName,
SearchClient indexSectionClient,
string searchIndexName,
SearchIndexClient searchIndexClient,
DocumentAnalysisClient documentAnalysisClient,
BlobContainerClient corpusContainerClient,
ILogger<AzureSearchEmbedService> logger) : IEmbedService
ILogger<AzureSearchEmbedService>? logger) : IEmbedService
{
[GeneratedRegex("[^0-9a-zA-Z_-]")]
private static partial Regex MatchInSetRegex();
Expand All @@ -16,9 +23,6 @@ public async Task<bool> EmbedBlobAsync(Stream blobStream, string blobName)
{
try
{
var searchIndexName = Environment.GetEnvironmentVariable(
"AZURE_SEARCH_INDEX") ?? "gptkbindex";

await EnsureSearchIndexAsync(searchIndexName);

var pageMap = await GetDocumentTextAsync(blobStream, blobName);
Expand All @@ -41,67 +45,94 @@ public async Task<bool> EmbedBlobAsync(Stream blobStream, string blobName)
}
catch (Exception exception)
{
logger.LogError(
logger?.LogError(
exception, "Failed to embed blob '{BlobName}'", blobName);

return false;
}
}

private async Task EnsureSearchIndexAsync(string searchIndexName)
public async Task CreateSearchIndexAsync(string searchIndexName)
{
var indexNames = searchIndexClient.GetIndexNamesAsync();
await foreach (var page in indexNames.AsPages())
string vectorSearchConfigName = "my-vector-config";
string vectorSearchProfile = "my-vector-profile";
var index = new SearchIndex(searchIndexName)
{
if (page.Values.Any(indexName => indexName == searchIndexName))
VectorSearch = new()
{
logger.LogWarning(
"Search index '{SearchIndexName}' already exists", searchIndexName);
return;
Algorithms =
{
new HnswVectorSearchAlgorithmConfiguration(vectorSearchConfigName)
},
Profiles =
{
new VectorSearchProfile(vectorSearchProfile, vectorSearchConfigName)
}
}

var index = new SearchIndex(searchIndexName)
{
},
Fields =
{
new SimpleField("id", SearchFieldDataType.String) { IsKey = true },
new SearchableField("content") { AnalyzerName = LexicalAnalyzerName.EnMicrosoft },
new SimpleField("category", SearchFieldDataType.String) { IsFacetable = true },
new SimpleField("sourcepage", SearchFieldDataType.String) { IsFacetable = true },
new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true },
new SearchField("embedding", SearchFieldDataType.Collection(SearchFieldDataType.Single))
{
new SimpleField("id", SearchFieldDataType.String) { IsKey = true },
new SearchableField("content") { AnalyzerName = "en.microsoft" },
new SimpleField("category", SearchFieldDataType.String) { IsFacetable = true },
new SimpleField("sourcepage", SearchFieldDataType.String) { IsFacetable = true },
new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true }
},
VectorSearchDimensions = 1536,
IsSearchable = true,
VectorSearchProfile = vectorSearchProfile,
}
},
SemanticSettings = new SemanticSettings
{
Configurations =
{
new SemanticConfiguration("default", new PrioritizedFields
{
new SemanticConfiguration("default", new PrioritizedFields
ContentFields =
{
ContentFields =
new SemanticField
{
new SemanticField
{
FieldName = "content"
}
FieldName = "content"
}
})
}
})
}
}
};

logger.LogInformation(
"Creating '{searchIndexName}' search index", searchIndexName);
logger?.LogInformation(
"Creating '{searchIndexName}' search index", searchIndexName);

await searchIndexClient.CreateIndexAsync(index);
}

public async Task EnsureSearchIndexAsync(string searchIndexName)
{
var indexNames = searchIndexClient.GetIndexNamesAsync();
await foreach (var page in indexNames.AsPages())
{
if (page.Values.Any(indexName => indexName == searchIndexName))
{
logger?.LogWarning(
"Search index '{SearchIndexName}' already exists", searchIndexName);
return;
}
}

await CreateSearchIndexAsync(searchIndexName);
}

private async Task<IReadOnlyList<PageDetail>> GetDocumentTextAsync(Stream blobStream, string blobName)
{
logger.LogInformation(
logger?.LogInformation(
"Extracting text from '{Blob}' using Azure Form Recognizer", blobName);

using var ms = new MemoryStream();
blobStream.CopyTo(ms);
ms.Position = 0;
AnalyzeDocumentOperation operation = documentAnalysisClient.AnalyzeDocument(
WaitUntil.Started, "prebuilt-layout", blobStream);
WaitUntil.Started, "prebuilt-layout", ms);

var offset = 0;
List<PageDetail> pageMap = [];
Expand Down Expand Up @@ -208,7 +239,7 @@ private async Task UploadCorpusAsync(string corpusBlobName, string text)
return;
}

logger.LogInformation("Uploading corpus '{CorpusBlobName}'", corpusBlobName);
logger?.LogInformation("Uploading corpus '{CorpusBlobName}'", corpusBlobName);

await using var stream = new MemoryStream(Encoding.UTF8.GetBytes(text));
await blobClient.UploadAsync(stream, new BlobHttpHeaders
Expand All @@ -231,7 +262,7 @@ private IEnumerable<Section> CreateSections(
var start = 0;
var end = length;

logger.LogInformation("Splitting '{BlobName}' into sections", blobName);
logger?.LogInformation("Splitting '{BlobName}' into sections", blobName);

while (start + SectionOverlap < length)
{
Expand Down Expand Up @@ -300,9 +331,9 @@ private IEnumerable<Section> CreateSections(
// If the section ends with an unclosed table, we need to start the next section with the table.
// If table starts inside SentenceSearchLimit, we ignore it, as that will cause an infinite loop for tables longer than MaxSectionLength
// If last table starts inside SectionOverlap, keep overlapping
if (logger.IsEnabled(LogLevel.Warning))
if (logger?.IsEnabled(LogLevel.Warning) is true)
{
logger.LogWarning("""
logger?.LogWarning("""
Section ends with unclosed table, starting next section with the
table at page {Offset} offset {Start} table start {LastTableStart}
""",
Expand Down Expand Up @@ -349,10 +380,10 @@ private static string BlobNameFromFilePage(string blobName, int page = 0) => Pat

private async Task IndexSectionsAsync(string searchIndexName, IEnumerable<Section> sections, string blobName)
{
var infoLoggingEnabled = logger.IsEnabled(LogLevel.Information);
if (infoLoggingEnabled)
var infoLoggingEnabled = logger?.IsEnabled(LogLevel.Information);
if (infoLoggingEnabled is true)
{
logger.LogInformation("""
logger?.LogInformation("""
Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
""",
blobName,
Expand All @@ -363,6 +394,8 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
var batch = new IndexDocumentsBatch<SearchDocument>();
foreach (var section in sections)
{
var embeddings = await openAIClient.GetEmbeddingsAsync(embeddingModelName, new Azure.AI.OpenAI.EmbeddingsOptions(section.Content.Replace('\r', ' ')));
var embedding = embeddings.Value.Data.FirstOrDefault()?.Embedding.ToArray() ?? [];
batch.Actions.Add(new IndexDocumentsAction<SearchDocument>(
IndexActionType.MergeOrUpload,
new SearchDocument
Expand All @@ -371,7 +404,8 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
["content"] = section.Content,
["category"] = section.Category,
["sourcepage"] = section.SourcePage,
["sourcefile"] = section.SourceFile
["sourcefile"] = section.SourceFile,
["embedding"] = embedding,
}));

iteration++;
Expand All @@ -380,9 +414,9 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
// Every one thousand documents, batch create.
IndexDocumentsResult result = await indexSectionClient.IndexDocumentsAsync(batch);
int succeeded = result.Results.Count(r => r.Succeeded);
if (infoLoggingEnabled)
if (infoLoggingEnabled is true)
{
logger.LogInformation("""
logger?.LogInformation("""
Indexed {Count} sections, {Succeeded} succeeded
""",
batch.Actions.Count,
Expand All @@ -399,9 +433,9 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
var index = new SearchIndex($"index-{batch.Actions.Count}");
IndexDocumentsResult result = await indexSectionClient.IndexDocumentsAsync(batch);
int succeeded = result.Results.Count(r => r.Succeeded);
if (logger.IsEnabled(LogLevel.Information))
if (logger?.IsEnabled(LogLevel.Information) is true)
{
logger.LogInformation("""
logger?.LogInformation("""
Indexed {Count} sections, {Succeeded} succeeded
""",
batch.Actions.Count,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// Copyright (c) Microsoft. All rights reserved.

using System.IO;

namespace EmbedFunctions.Services;

public sealed class EmbeddingAggregateService(
EmbedServiceFactory embedServiceFactory,
BlobContainerClient client,
ILogger<EmbeddingAggregateService> logger)
{
internal async Task EmbedBlobAsync(BlobClient client, Stream blobStream, string blobName)
internal async Task EmbedBlobAsync(Stream blobStream, string blobName)
{
try
{
Expand Down
6 changes: 0 additions & 6 deletions app/prepdocs/PrepareDocs/PageDetail.cs

This file was deleted.

4 changes: 4 additions & 0 deletions app/prepdocs/PrepareDocs/PrepareDocs.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@
<PackageReference Include="PdfSharpCore" />
<PackageReference Include="System.CommandLine" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\functions\EmbedFunctions\EmbedFunctions.csproj" />
</ItemGroup>

</Project>
18 changes: 18 additions & 0 deletions app/prepdocs/PrepareDocs/Program.Clients.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.


using EmbedFunctions.Services;
using Microsoft.Extensions.Logging;

internal static partial class Program
{
private static BlobContainerClient? s_corpusContainerClient;
Expand All @@ -16,6 +19,21 @@ internal static partial class Program
private static readonly SemaphoreSlim s_searchIndexLock = new(1);
private static readonly SemaphoreSlim s_searchLock = new(1);
private static readonly SemaphoreSlim s_openAILock = new(1);
private static readonly SemaphoreSlim s_embeddingLock = new(1);

private static Task<AzureSearchEmbedService> GetAzureSearchEmbedService(AppOptions options) =>
GetLazyClientAsync<AzureSearchEmbedService>(options, s_embeddingLock, async o =>
{
var searchIndexClient = await GetSearchIndexClientAsync(o);
var searchClient = await GetSearchClientAsync(o);
var documentClient = await GetFormRecognizerClientAsync(o);
var blobContainerClient = await GetBlobContainerClientAsync(o);
var openAIClient = await GetAzureOpenAIClientAsync(o);
var embeddingModelName = o.EmbeddingModelName ?? throw new ArgumentNullException(nameof(o.EmbeddingModelName));
var searchIndexName = o.SearchIndexName ?? throw new ArgumentNullException(nameof(o.SearchIndexName));

return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, null);
});

private static Task<BlobContainerClient> GetCorpusBlobContainerClientAsync(AppOptions options) =>
GetLazyClientAsync<BlobContainerClient>(options, s_corpusContainerLock, static async o =>
Expand Down
Loading

0 comments on commit eb65d06

Please sign in to comment.