Skip to content

Commit

Permalink
Tweaks DnsGroupRacerClient to fault tasks when the DNS response is no…
Browse files Browse the repository at this point in the history
…t successful.
  • Loading branch information
alanedwardes committed Dec 26, 2023
1 parent 6e6c955 commit 9e195ff
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 23 deletions.
12 changes: 10 additions & 2 deletions src/Ae.Dns.Client/DnsGroupRacerClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public async Task<DnsMessage> Query(DnsMessage query, CancellationToken token)
var randomisedClients = randomisedGroups.ToDictionary(x => x.Value.OrderBy(y => Guid.NewGuid()).First(), x => x.Key);

// Start the query tasks (and create a lookup from query to client)
var queries = randomisedClients.Keys.ToDictionary(client => client.Query(query, token), client => client);
var queries = randomisedClients.Keys.ToDictionary(client => QueryWrapped(client, query, token), client => client);

// Select a winning task
var winningTask = await TaskRacer.RaceTasks(queries.Select(x => x.Key), async result => result.IsFaulted || (await result).EncounteredResolverError());
var winningTask = await TaskRacer.RaceTasks(queries.Select(x => x.Key));

// If tasks faulted, log the reason
var faultedTasks = queries.Keys.Where(x => x.IsFaulted).ToArray();
Expand All @@ -63,6 +63,7 @@ public async Task<DnsMessage> Query(DnsMessage query, CancellationToken token)
if (winningTask.IsFaulted)
{
_logger.LogError("All tasks using {FaultedClients} from groups {FaultedGroups} failed for query {Query} in {ElapsedMilliseconds}ms", faultedClientsString, faultedGroupsString, query, sw.ElapsedMilliseconds);
return DnsQueryFactory.CreateErrorResponse(query);
}
else
{
Expand All @@ -84,6 +85,13 @@ public async Task<DnsMessage> Query(DnsMessage query, CancellationToken token)
return winningAnswer;
}

private async Task<DnsMessage> QueryWrapped(IDnsClient client, DnsMessage query, CancellationToken token)
{
var answer = await client.Query(query, token);
answer.EnsureSuccessResponseCode();
return answer;
}

/// <inheritdoc/>
public void Dispose()
{
Expand Down
12 changes: 3 additions & 9 deletions src/Ae.Dns.Client/Internal/TaskRacer.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
using System;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace Ae.Dns.Client.Internal
{
internal static class TaskRacer
{
public static async Task<Task<TResult>> RaceTasks<TResult>(IEnumerable<Task<TResult>> tasks, Func<Task<TResult>, Task<bool>> isFailed)
public static async Task<Task<TResult>> RaceTasks<TResult>(IEnumerable<Task<TResult>> tasks)
{
var queue = tasks.ToList();

Expand All @@ -17,14 +16,9 @@ public static async Task<Task<TResult>> RaceTasks<TResult>(IEnumerable<Task<TRes
task = await Task.WhenAny(queue);
queue.Remove(task);
}
while (await isFailed(task) && queue.Count > 0);
while (task.IsFaulted && queue.Count > 0);

return task;
}

public static async Task<Task<TResult>> RaceTasks<TResult>(IEnumerable<Task<TResult>> tasks)
{
return await RaceTasks(tasks, task => Task.FromResult(task.IsFaulted));
}
}
}
17 changes: 5 additions & 12 deletions src/Ae.Dns.Protocol/DnsMessageExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,14 @@ internal static class DnsMessageExtensions
return null;
}

/// <summary>
/// Returns true if the resolver encountered an error.
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
public static bool EncounteredResolverError(this DnsMessage message)
public static void EnsureSuccessResponseCode(this DnsMessage message)
{
if (!message.Header.IsQueryResponse)
if (message.Header.ResponseCode == DnsResponseCode.ServFail ||
message.Header.ResponseCode == DnsResponseCode.NotImp ||
message.Header.ResponseCode == DnsResponseCode.NotAuth)
{
return false;
throw new Exception($"The response code {message.Header.ResponseCode} does not indicate success for message {message}");
}

return message.Header.ResponseCode == DnsResponseCode.ServFail ||
message.Header.ResponseCode == DnsResponseCode.NotImp ||
message.Header.ResponseCode == DnsResponseCode.NotAuth;
}

public static bool TryParseIpAddressFromReverseLookup(this DnsMessage message, out IPAddress? address)
Expand Down
17 changes: 17 additions & 0 deletions src/Ae.Dns.Protocol/DnsQueryFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ public static DnsHeader Clone(DnsHeader header)
};
}

internal static DnsMessage CreateErrorResponse(DnsMessage message, DnsResponseCode responseCode = DnsResponseCode.ServFail)
{
return new DnsMessage
{
Header = new DnsHeader
{
Id = message.Header.Id,
Host = message.Header.Host,
QueryType = message.Header.QueryType,
QueryClass = message.Header.QueryClass,
OperationCode = message.Header.OperationCode,
ResponseCode = responseCode,
IsQueryResponse = true
}
};
}

/// <summary>
/// Truncate the answer (for example, if it has overflowed the size of a UDP packet).
/// </summary>
Expand Down

0 comments on commit 9e195ff

Please sign in to comment.