diff --git a/Core/Logging/LoggerExtensions.LoggedEnumerable.cs b/Core/Logging/LoggerExtensions.LoggedEnumerable.cs new file mode 100644 index 0000000..04a5056 --- /dev/null +++ b/Core/Logging/LoggerExtensions.LoggedEnumerable.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; + +using Microsoft.Extensions.Logging; + +namespace MonoDevelop.Xml.Logging; + +public static partial class LoggerExtensions +{ + class LoggedEnumerable : IEnumerable + { + readonly IEnumerable enumerable; + readonly ILogger logger; + readonly string? originMember; + + public LoggedEnumerable (IEnumerable enumerable, ILogger logger, string originMember) + { + this.enumerable = enumerable; + this.logger = logger; + this.originMember = originMember; + } + + void Log (Exception ex) => logger.LogInternalError (ex, originMember); + + public IEnumerator GetEnumerator () + { + try { + var enumerator = enumerable.GetEnumerator (); + return new LoggedEnumerator (this, enumerator); + } catch (Exception ex) { + Log (ex); + throw; + } + } + + IEnumerator IEnumerable.GetEnumerator () => GetEnumerator (); + + class LoggedEnumerator : IEnumerator + { + readonly LoggedEnumerable parent; + readonly IEnumerator enumerator; + + public LoggedEnumerator (LoggedEnumerable parent, IEnumerator enumerator) + { + this.parent = parent; + this.enumerator = enumerator; + } + + public T Current { + get { + try { + return enumerator.Current; + } catch (Exception ex) { + parent.Log (ex); + throw; + } + } + } + + object? IEnumerator.Current => Current; + + public void Dispose () + { + try { + enumerator.Dispose (); + } catch (Exception ex) { + parent.Log (ex); + throw; + } + } + + public bool MoveNext () + { + try { + return enumerator.MoveNext (); + } catch (Exception ex) { + parent.Log (ex); + return false; + } + } + + public void Reset () + { + try { + enumerator.Reset (); + } catch (Exception ex) { + parent.Log (ex); + } + } + } + } +} \ No newline at end of file diff --git a/Core/Logging/LoggerExtensions.cs b/Core/Logging/LoggerExtensions.cs index b42aa60..54ad8b8 100644 --- a/Core/Logging/LoggerExtensions.cs +++ b/Core/Logging/LoggerExtensions.cs @@ -20,21 +20,37 @@ public static void LogExceptionsAndForget (this Task task, ILogger logger, [Call task.CatchAndLogIfFaulted (logger, originMember); } - public static Task WithExceptionLogger (this Task task, ILogger logger, [CallerMemberName] string? originMember = default) + /// + /// Attaches a continution to the task that logs any exception thrown by the task, and returns the task. + /// + public static Task WithExceptionLogger (this Task task, ILogger logger, [CallerMemberName] string? originMember = default) { task.CatchAndLogIfFaulted (logger, originMember); return task; } - /// - /// Attaches a continution to the task that logs any exception thrown by the task, and returns the task. - /// - public static Task WithExceptionLogger (this Task task, ILogger logger, [CallerMemberName] string? originMember = default) + public static Task WithExceptionLogger (this Task task, ILogger logger, [CallerMemberName] string? originMember = default) { task.CatchAndLogIfFaulted (logger, originMember); return task; } + public static Task> WithExceptionLogger (this Task> task, ILogger logger, [CallerMemberName] string? originMember = default) + { + return task.ContinueWith (t => { + try { + return t.Result.WithExceptionLogger (logger, originMember); + } catch (Exception ex) { + LogInternalError (logger, ex, originMember); + throw; + } + }, + TaskContinuationOptions.ExecuteSynchronously); + } + + public static IEnumerable WithExceptionLogger (this IEnumerable enumerable, ILogger logger, [CallerMemberName] string? originMember = default) + => new LoggedEnumerable (enumerable, logger, originMember ?? throw new ArgumentNullException (nameof (originMember))); + static Task CatchAndLogIfFaulted (this Task task, ILogger logger, string? originMember) { if (originMember is null) { @@ -44,7 +60,7 @@ static Task CatchAndLogIfFaulted (this Task task, ILogger logger, string? origin _ = task.ContinueWith ( t => LogExceptions (logger, originMember, t), default, - TaskContinuationOptions.OnlyOnFaulted, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default );