Skip to content

Commit

Permalink
Added Query iterator method (#80)
Browse files Browse the repository at this point in the history
* Added Query iterator method

* Added unexpected case handling
  • Loading branch information
BlackGad authored Jul 10, 2024
1 parent 475595b commit 30199ee
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 39 deletions.
20 changes: 20 additions & 0 deletions Milvus.Client/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ internal static class Constants
/// </summary>
internal const string Offset = "offset";

/// <summary>
/// Top parameter key name.
/// </summary>
internal const string Limit = "limit";

/// <summary>
/// Top parameter key name.
/// </summary>
internal const string BatchSize = "batch_size";

/// <summary>
/// Top parameter key name.
/// </summary>
internal const string Iterator = "iterator";

/// <summary>
/// Top parameter key name.
/// </summary>
internal const string ReduceStopForBest = "reduce_stop_for_best";

/// <summary>
/// Key name in parameters.<see cref="Client.IndexType"/>
/// </summary>
Expand Down
229 changes: 190 additions & 39 deletions Milvus.Client/MilvusCollection.Entity.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using System.Buffers;
using System.Buffers.Binary;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.Json;
using Google.Protobuf.Collections;
Expand Down Expand Up @@ -427,61 +427,144 @@ public async Task<IReadOnlyList<FieldData>> QueryAsync(
Expr = expression
};

if (parameters is not null)
PopulateQueryRequestFromParameters(request, parameters);

var response = await _client.InvokeAsync(
_client.GrpcClient.QueryAsync,
request,
static r => r.Status,
cancellationToken)
.ConfigureAwait(false);

return ProcessReturnedFieldData(response.FieldsData);
}

/// <summary>
/// Retrieves rows from a collection via scalar filtering based on a boolean expression using iterator.
/// </summary>
/// <param name="expression">A boolean expression determining which rows are to be returned.</param>
/// <param name="batchSize">Batch size that will be used for every iteration request. Must be between 1 and 16384.</param>
/// <param name="parameters">Various additional optional parameters to configure the query.</param>
/// <param name="cancellationToken">
/// The token to monitor for cancellation requests. The default value is <see cref="CancellationToken.None" />.
/// </param>
/// <returns>A list of <see cref="FieldData{TData}" /> instances with the query results.</returns>
public async IAsyncEnumerable<IReadOnlyList<FieldData>> QueryWithIteratorAsync(
string? expression = null,
int batchSize = 1000,
QueryParameters? parameters = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if ((parameters?.Offset ?? 0) != 0)
{
if (parameters.TimeTravelTimestamp is not null)
throw new MilvusException("Not support offset when searching iteration");
}

var describeResponse = await _client.InvokeAsync(
_client.GrpcClient.DescribeCollectionAsync,
new DescribeCollectionRequest { CollectionName = Name },
r => r.Status,
cancellationToken)
.ConfigureAwait(false);

var pkField = describeResponse.Schema.Fields.FirstOrDefault(x => x.IsPrimaryKey);
if (pkField == null)
{
throw new MilvusException("Schema must contain pk field");
}

var isUserRequestPkField = parameters?.OutputFieldsInternal?.Contains(pkField.Name) ?? false;
var userExpression = expression;
var userLimit = parameters?.Limit ?? int.MaxValue;

QueryRequest request = new()
{
CollectionName = Name,
Expr = (userExpression, pkField) switch
{
request.TravelTimestamp = parameters.TimeTravelTimestamp.Value;
// If user expression is not null, we should use it
{userExpression: not null} => userExpression,
// If user expression is null and pk field is string
{pkField.DataType: DataType.VarChar} => $"{pkField.Name} != ''",
// If user expression is null and pk field is not string
_ => $"{pkField.Name} < {long.MaxValue}",
}
};

if (parameters.PartitionNamesInternal?.Count > 0)
PopulateQueryRequestFromParameters(request, parameters);

// Request id field in any case to proceed with an iterations
if (!isUserRequestPkField) request.OutputFields.Add(pkField.Name);

// Replace parameters required for iterator
ReplaceKeyValueItems(request.QueryParams,
new Grpc.KeyValuePair {Key = Constants.Iterator, Value = "True"},
new Grpc.KeyValuePair {Key = Constants.ReduceStopForBest, Value = "True"},
new Grpc.KeyValuePair {Key = Constants.BatchSize, Value = batchSize.ToString(CultureInfo.InvariantCulture)},
new Grpc.KeyValuePair {Key = Constants.Offset, Value = 0.ToString(CultureInfo.InvariantCulture)},
new Grpc.KeyValuePair {Key = Constants.Limit, Value = Math.Min(batchSize, userLimit).ToString(CultureInfo.InvariantCulture)});

var processedItemsCount = 0;
while (!cancellationToken.IsCancellationRequested)
{
var response = await _client.InvokeAsync(
_client.GrpcClient.QueryAsync,
request,
static r => r.Status,
cancellationToken)
.ConfigureAwait(false);

object? pkLastValue;
int processedDuringIterationCount;
var pkFieldsData = response.FieldsData.Single(x => x.FieldId == pkField.FieldID);
if (pkField.DataType == DataType.VarChar)
{
request.PartitionNames.AddRange(parameters.PartitionNamesInternal);
pkLastValue = pkFieldsData.Scalars.StringData.Data.LastOrDefault();
processedDuringIterationCount = pkFieldsData.Scalars.StringData.Data.Count;
}

if (parameters.OutputFieldsInternal?.Count > 0)
else
{
request.OutputFields.AddRange(parameters.OutputFieldsInternal);
pkLastValue = pkFieldsData.Scalars.IntData.Data.LastOrDefault();
processedDuringIterationCount = pkFieldsData.Scalars.IntData.Data.Count;
}

if (parameters.Offset is not null)
// If there are no more items to process, we should break the loop
if(processedDuringIterationCount == 0) yield break;

// Respond with processed data
if (!isUserRequestPkField)
{
request.QueryParams.Add(new Grpc.KeyValuePair
{
Key = Constants.Offset,
Value = parameters.Offset.Value.ToString(CultureInfo.InvariantCulture)
});
// Filter out extra field if user didn't request it
response.FieldsData.Remove(pkFieldsData);
}
yield return ProcessReturnedFieldData(response.FieldsData);

if (parameters.Limit is not null)
{
request.QueryParams.Add(new Grpc.KeyValuePair
processedItemsCount += processedDuringIterationCount;
var leftItemsCount = userLimit - processedItemsCount;

// If user limit is reached, we should break the loop
if(leftItemsCount <= 0) yield break;

// Setup next iteration limit and expression
ReplaceKeyValueItems(
request.QueryParams,
new Grpc.KeyValuePair
{
Key = "limit", Value = parameters.Limit.Value.ToString(CultureInfo.InvariantCulture)
Key = Constants.Limit,
Value = Math.Min(batchSize, leftItemsCount).ToString(CultureInfo.InvariantCulture)
});
}
}

// Note that we send both the consistency level and the guarantee timestamp, although the latter is derived
// from the former and should be sufficient.
if (parameters?.ConsistencyLevel is null)
{
request.UseDefaultConsistency = true;
request.GuaranteeTimestamp = CalculateGuaranteeTimestamp(Name, ConsistencyLevel.Session, userProvidedGuaranteeTimestamp: null);
}
else
{
request.ConsistencyLevel = (Grpc.ConsistencyLevel)parameters.ConsistencyLevel.Value;
request.GuaranteeTimestamp =
CalculateGuaranteeTimestamp(Name, parameters.ConsistencyLevel.Value,
parameters.GuaranteeTimestamp);
}
var nextExpression = pkField.DataType == DataType.VarChar
? $"{pkField.Name} > '{pkLastValue}'"
: $"{pkField.Name} > {pkLastValue}";

QueryResults response =
await _client.InvokeAsync(_client.GrpcClient.QueryAsync, request, static r => r.Status, cancellationToken)
.ConfigureAwait(false);
if (!string.IsNullOrWhiteSpace(userExpression))
{
nextExpression += $" and {userExpression}";
}

return ProcessReturnedFieldData(response.FieldsData);
request.Expr = nextExpression;
}
}

/// <summary>
Expand Down Expand Up @@ -694,4 +777,72 @@ ulong CalculateGuaranteeTimestamp(

return guaranteeTimestamp;
}

private static void ReplaceKeyValueItems(RepeatedField<Grpc.KeyValuePair> collection, params Grpc.KeyValuePair[] pairs)
{
var obsoleteParameterKeys = pairs.Select(x => x.Key).Distinct().ToArray();
var obsoleteParameters = collection.Where(x => obsoleteParameterKeys.Contains(x.Key)).ToArray();
foreach (var field in obsoleteParameters)
{
collection.Remove(field);
}

foreach (var pair in pairs)
{
collection.Add(pair);
}
}

private void PopulateQueryRequestFromParameters(QueryRequest request, QueryParameters? parameters)
{
if (parameters is not null)
{
if (parameters.TimeTravelTimestamp is not null)
{
request.TravelTimestamp = parameters.TimeTravelTimestamp.Value;
}

if (parameters.PartitionNamesInternal?.Count > 0)
{
request.PartitionNames.AddRange(parameters.PartitionNamesInternal);
}

if (parameters.OutputFieldsInternal?.Count > 0)
{
request.OutputFields.AddRange(parameters.OutputFieldsInternal);
}

if (parameters.Offset is not null)
{
request.QueryParams.Add(new Grpc.KeyValuePair
{
Key = Constants.Offset,
Value = parameters.Offset.Value.ToString(CultureInfo.InvariantCulture)
});
}

if (parameters.Limit is not null)
{
request.QueryParams.Add(new Grpc.KeyValuePair
{
Key = Constants.Limit, Value = parameters.Limit.Value.ToString(CultureInfo.InvariantCulture)
});
}
}

// Note that we send both the consistency level and the guarantee timestamp, although the latter is derived
// from the former and should be sufficient.
if (parameters?.ConsistencyLevel is null)
{
request.UseDefaultConsistency = true;
request.GuaranteeTimestamp = CalculateGuaranteeTimestamp(Name, ConsistencyLevel.Session, userProvidedGuaranteeTimestamp: null);
}
else
{
request.ConsistencyLevel = (Grpc.ConsistencyLevel)parameters.ConsistencyLevel.Value;
request.GuaranteeTimestamp =
CalculateGuaranteeTimestamp(Name, parameters.ConsistencyLevel.Value,
parameters.GuaranteeTimestamp);
}
}
}

0 comments on commit 30199ee

Please sign in to comment.