diff --git a/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs b/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs index 881b679a19f3ae..0c22c663d32a58 100644 --- a/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs +++ b/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs @@ -73,6 +73,8 @@ public static partial class AsyncEnumerable public static System.Collections.Generic.IAsyncEnumerable GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } public static System.Collections.Generic.IAsyncEnumerable GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Func> elementSelector, System.Func, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } public static System.Collections.Generic.IAsyncEnumerable GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func elementSelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> GroupJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func> outerKeySelector, System.Func> innerKeySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> GroupJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } public static System.Collections.Generic.IAsyncEnumerable GroupJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func> outerKeySelector, System.Func> innerKeySelector, System.Func, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } public static System.Collections.Generic.IAsyncEnumerable GroupJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } public static System.Collections.Generic.IAsyncEnumerable<(int Index, TSource Item)> Index(this System.Collections.Generic.IAsyncEnumerable source) { throw null; } diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupJoin.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupJoin.cs index 3673aa85956f52..af56032f50f520 100644 --- a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupJoin.cs +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupJoin.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections; using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Threading; @@ -10,6 +11,120 @@ namespace System.Linq { public static partial class AsyncEnumerable { + /// Correlates the elements of two sequences based on key equality and groups the results. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// An to use to hash and compare keys. + /// + /// An that contains elements of type + /// where each grouping contains the outer element as the key and the matching inner elements. + /// + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable> GroupJoin( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + IEqualityComparer? comparer = null) + { + ArgumentNullException.ThrowIfNull(outer); + ArgumentNullException.ThrowIfNull(inner); + ArgumentNullException.ThrowIfNull(outerKeySelector); + ArgumentNullException.ThrowIfNull(innerKeySelector); + + return + outer.IsKnownEmpty() ? Empty>() : + Impl(outer, inner, outerKeySelector, innerKeySelector, comparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await using IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + + if (await e.MoveNextAsync()) + { + AsyncLookup lookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken); + do + { + TOuter item = e.Current; + yield return new AsyncGroupJoinGrouping(item, lookup[outerKeySelector(item)]); + } + while (await e.MoveNextAsync()); + } + } + } + + /// Correlates the elements of two sequences based on key equality and groups the results. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// An to use to hash and compare keys. + /// + /// An that contains elements of type + /// where each grouping contains the outer element as the key and the matching inner elements. + /// + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable> GroupJoin( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + IEqualityComparer? comparer = null) + { + ArgumentNullException.ThrowIfNull(outer); + ArgumentNullException.ThrowIfNull(inner); + ArgumentNullException.ThrowIfNull(outerKeySelector); + ArgumentNullException.ThrowIfNull(innerKeySelector); + + return + outer.IsKnownEmpty() ? Empty>() : + Impl(outer, inner, outerKeySelector, innerKeySelector, comparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await using IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + + if (await e.MoveNextAsync()) + { + AsyncLookup lookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken); + do + { + TOuter item = e.Current; + yield return new AsyncGroupJoinGrouping( + item, + lookup[await outerKeySelector(item, cancellationToken)]); + } + while (await e.MoveNextAsync()); + } + } + } + /// Correlates the elements of two sequences based on key equality and groups the results. /// /// @@ -143,4 +258,22 @@ lookup[await outerKeySelector(item, cancellationToken)], } } } + + internal sealed class AsyncGroupJoinGrouping : IGrouping + { + private readonly TKey _key; + private readonly IEnumerable _elements; + + public AsyncGroupJoinGrouping(TKey key, IEnumerable elements) + { + _key = key; + _elements = elements; + } + + public TKey Key => _key; + + public IEnumerator GetEnumerator() => _elements.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } } diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/GroupJoinTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/GroupJoinTests.cs index 305faf3c77681f..e87e1de4f5c0ce 100644 --- a/src/libraries/System.Linq.AsyncEnumerable/tests/GroupJoinTests.cs +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/GroupJoinTests.cs @@ -26,6 +26,20 @@ public void InvalidInputs_Throws() AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, (Func, CancellationToken, ValueTask>)null)); } + [Fact] + public void InvalidInputs_WithoutResultSelector_Throws() + { + AssertExtensions.Throws("outer", () => AsyncEnumerable.GroupJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), outer => outer, inner => inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, outer => outer, inner => inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null, inner => inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, (Func)null)); + + AssertExtensions.Throws("outer", () => AsyncEnumerable.GroupJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, async (outer, ct) => outer, async (inner, ct) => inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null, async (inner, ct) => inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, (Func>)null)); + } + [Fact] public void Empty_ProducesEmpty() // validating an optimization / implementation detail { @@ -33,6 +47,13 @@ public void Empty_ProducesEmpty() // validating an optimization / implementation Assert.Same(AsyncEnumerable.Empty(), AsyncEnumerable.Empty().GroupJoin(CreateSource(1, 2, 3), async (s, ct) => s, async (i, ct) => i.ToString(), async (s, e, ct) => s)); } + [Fact] + public void Empty_WithoutResultSelector_ProducesEmpty() + { + Assert.Same(AsyncEnumerable.Empty>(), AsyncEnumerable.Empty().GroupJoin(CreateSource(1, 2, 3), s => s, i => i.ToString())); + Assert.Same(AsyncEnumerable.Empty>(), AsyncEnumerable.Empty().GroupJoin(CreateSource(1, 2, 3), async (s, ct) => s, async (i, ct) => i.ToString())); + } + [Fact] public async Task VariousValues_MatchesEnumerable_String() { @@ -55,6 +76,36 @@ await AssertEqual( } } + [Fact] + public async Task VariousValues_WithoutResultSelector_MatchesEnumerable() + { + int[] outer = [1, 2, 3]; + int[] inner = [1, 2, 2, 3, 3, 3]; + + foreach (IAsyncEnumerable outerSource in CreateSources(outer)) + foreach (IAsyncEnumerable innerSource in CreateSources(inner)) + { + var expected = outer.GroupJoin(inner, o => o, i => i); + var result = await outerSource.GroupJoin(innerSource, o => o, i => i).ToListAsync(); + + Assert.Equal(expected.Count(), result.Count); + foreach (var (exp, act) in expected.Zip(result)) + { + Assert.Equal(exp.Key, act.Key); + Assert.Equal(exp.ToList(), act.ToList()); + } + + var resultAsync = await outerSource.GroupJoin(innerSource, async (o, ct) => o, async (i, ct) => i).ToListAsync(); + + Assert.Equal(expected.Count(), resultAsync.Count); + foreach (var (exp, act) in expected.Zip(resultAsync)) + { + Assert.Equal(exp.Key, act.Key); + Assert.Equal(exp.ToList(), act.ToList()); + } + } + } + [Fact] public async Task Cancellation_Cancels() { @@ -167,5 +218,31 @@ public async Task InterfaceCalls_ExpectedCounts() Assert.Equal(4, inner.CurrentCount); Assert.Equal(1, inner.DisposeAsyncCount); } + + [Fact] + public async Task InterfaceCalls_WithoutResultSelector_ExpectedCounts() + { + TrackingAsyncEnumerable outer, inner; + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.GroupJoin(inner, outer => outer, inner => inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.GroupJoin(inner, async (outer, ct) => outer, async (inner, ct) => inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + } } } diff --git a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs index 2d63b858f0bd26..0cf1cdecae6c38 100644 --- a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs +++ b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs @@ -108,6 +108,8 @@ public static partial class Queryable public static System.Linq.IQueryable GroupBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Linq.Expressions.Expression, TResult>> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable GroupBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Linq.Expressions.Expression> elementSelector, System.Linq.Expressions.Expression, TResult>> resultSelector) { throw null; } public static System.Linq.IQueryable GroupBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Linq.Expressions.Expression> elementSelector, System.Linq.Expressions.Expression, TResult>> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } + public static System.Linq.IQueryable> GroupJoin(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector) { throw null; } + public static System.Linq.IQueryable> GroupJoin(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable GroupJoin(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression, TResult>> resultSelector) { throw null; } public static System.Linq.IQueryable GroupJoin(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression, TResult>> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable<(int Index, TSource Item)> Index(this System.Linq.IQueryable source) { throw null; } diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs index 1290b8f6819e23..fde668fbaa2597 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs @@ -408,6 +408,36 @@ public static IQueryable Join(this IQuer outer.Expression, GetSourceExpression(inner), Expression.Quote(outerKeySelector), Expression.Quote(innerKeySelector), Expression.Quote(resultSelector), Expression.Constant(comparer, typeof(IEqualityComparer)))); } + [DynamicDependency("GroupJoin`3", typeof(Enumerable))] + public static IQueryable> GroupJoin(this IQueryable outer, IEnumerable inner, Expression> outerKeySelector, Expression> innerKeySelector) + { + ArgumentNullException.ThrowIfNull(outer); + ArgumentNullException.ThrowIfNull(inner); + ArgumentNullException.ThrowIfNull(outerKeySelector); + ArgumentNullException.ThrowIfNull(innerKeySelector); + + return outer.Provider.CreateQuery>( + Expression.Call( + null, + new Func, IEnumerable, Expression>, Expression>, IQueryable>>(GroupJoin).Method, + outer.Expression, GetSourceExpression(inner), Expression.Quote(outerKeySelector), Expression.Quote(innerKeySelector))); + } + + [DynamicDependency("GroupJoin`3", typeof(Enumerable))] + public static IQueryable> GroupJoin(this IQueryable outer, IEnumerable inner, Expression> outerKeySelector, Expression> innerKeySelector, IEqualityComparer? comparer) + { + ArgumentNullException.ThrowIfNull(outer); + ArgumentNullException.ThrowIfNull(inner); + ArgumentNullException.ThrowIfNull(outerKeySelector); + ArgumentNullException.ThrowIfNull(innerKeySelector); + + return outer.Provider.CreateQuery>( + Expression.Call( + null, + new Func, IEnumerable, Expression>, Expression>, IEqualityComparer, IQueryable>>(GroupJoin).Method, + outer.Expression, GetSourceExpression(inner), Expression.Quote(outerKeySelector), Expression.Quote(innerKeySelector), Expression.Constant(comparer, typeof(IEqualityComparer)))); + } + [DynamicDependency("GroupJoin`4", typeof(Enumerable))] public static IQueryable GroupJoin(this IQueryable outer, IEnumerable inner, Expression> outerKeySelector, Expression> innerKeySelector, Expression, TResult>> resultSelector) { diff --git a/src/libraries/System.Linq.Queryable/tests/GroupJoinTests.cs b/src/libraries/System.Linq.Queryable/tests/GroupJoinTests.cs index 3db5cb8cad203e..8d79f7e1c17919 100644 --- a/src/libraries/System.Linq.Queryable/tests/GroupJoinTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/GroupJoinTests.cs @@ -287,5 +287,67 @@ public void GroupJoin2() var count = new[] { 0, 1, 2 }.AsQueryable().GroupJoin(new[] { 1, 2, 3 }, n1 => n1, n2 => n2, (n1, n2) => n1, EqualityComparer.Default).Count(); Assert.Equal(3, count); } + + [Fact] + public void GroupJoinWithoutResultSelector() + { + var result = new[] { 0, 1, 2 }.AsQueryable().GroupJoin(new[] { 1, 2, 3 }, n1 => n1, n2 => n2).ToList(); + Assert.Equal(3, result.Count); + Assert.Equal(0, result[0].Key); + Assert.Empty(result[0]); + Assert.Equal(1, result[1].Key); + Assert.Single(result[1]); + Assert.Equal(2, result[2].Key); + Assert.Single(result[2]); + } + + [Fact] + public void GroupJoinWithoutResultSelector_OuterNull() + { + IQueryable outer = null; + int[] inner = { 1, 2, 3 }; + + AssertExtensions.Throws("outer", () => outer.GroupJoin(inner.AsQueryable(), n1 => n1, n2 => n2)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_InnerNull() + { + int[] outer = { 0, 1, 2 }; + IQueryable inner = null; + + AssertExtensions.Throws("inner", () => outer.AsQueryable().GroupJoin(inner, n1 => n1, n2 => n2)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_OuterKeySelectorNull() + { + int[] outer = { 0, 1, 2 }; + int[] inner = { 1, 2, 3 }; + + AssertExtensions.Throws("outerKeySelector", () => outer.AsQueryable().GroupJoin(inner.AsQueryable(), null, n2 => n2)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_InnerKeySelectorNull() + { + int[] outer = { 0, 1, 2 }; + int[] inner = { 1, 2, 3 }; + + AssertExtensions.Throws("innerKeySelector", () => outer.AsQueryable().GroupJoin(inner.AsQueryable(), n1 => n1, null)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_CustomComparer() + { + var result = new[] { "Tim", "Bob", "Robert" }.AsQueryable().GroupJoin(new[] { "miT", "Robert" }, n1 => n1, n2 => n2, new AnagramEqualityComparer()).ToList(); + Assert.Equal(3, result.Count); + Assert.Equal("Tim", result[0].Key); + Assert.Single(result[0]); + Assert.Equal("Bob", result[1].Key); + Assert.Empty(result[1]); + Assert.Equal("Robert", result[2].Key); + Assert.Single(result[2]); + } } } diff --git a/src/libraries/System.Linq/ref/System.Linq.cs b/src/libraries/System.Linq/ref/System.Linq.cs index 0c2d82b2111766..09bbb07c5884a7 100644 --- a/src/libraries/System.Linq/ref/System.Linq.cs +++ b/src/libraries/System.Linq/ref/System.Linq.cs @@ -81,6 +81,8 @@ public static System.Collections.Generic.IEnumerable< public static System.Collections.Generic.IEnumerable GroupBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable GroupBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Func elementSelector, System.Func, TResult> resultSelector) { throw null; } public static System.Collections.Generic.IEnumerable GroupBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Func elementSelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } + public static System.Collections.Generic.IEnumerable> GroupJoin(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector) { throw null; } + public static System.Collections.Generic.IEnumerable> GroupJoin(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable GroupJoin(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func, TResult> resultSelector) { throw null; } public static System.Collections.Generic.IEnumerable GroupJoin(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable<(int Index, TSource Item)> Index(this System.Collections.Generic.IEnumerable source) { throw null; } diff --git a/src/libraries/System.Linq/src/System/Linq/GroupJoin.cs b/src/libraries/System.Linq/src/System/Linq/GroupJoin.cs index b59ca859ec568c..d50958b8615b50 100644 --- a/src/libraries/System.Linq/src/System/Linq/GroupJoin.cs +++ b/src/libraries/System.Linq/src/System/Linq/GroupJoin.cs @@ -1,12 +1,46 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections; using System.Collections.Generic; namespace System.Linq { public static partial class Enumerable { + public static IEnumerable> GroupJoin(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector) => + GroupJoin(outer, inner, outerKeySelector, innerKeySelector, comparer: null); + + public static IEnumerable> GroupJoin(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, IEqualityComparer? comparer) + { + if (outer is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.outer); + } + + if (inner is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.inner); + } + + if (outerKeySelector is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.outerKeySelector); + } + + if (innerKeySelector is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.innerKeySelector); + } + + if (IsEmptyArray(outer)) + { + return []; + } + + return GroupJoinIterator(outer, inner, outerKeySelector, innerKeySelector, comparer); + } + public static IEnumerable GroupJoin(this IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector) => GroupJoin(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer: null); @@ -45,6 +79,21 @@ public static IEnumerable GroupJoin(this return GroupJoinIterator(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); } + private static IEnumerable> GroupJoinIterator(IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, IEqualityComparer? comparer) + { + using IEnumerator e = outer.GetEnumerator(); + if (e.MoveNext()) + { + Lookup lookup = Lookup.CreateForJoin(inner, innerKeySelector, comparer); + do + { + TOuter item = e.Current; + yield return new GroupJoinGrouping(item, lookup[outerKeySelector(item)]); + } + while (e.MoveNext()); + } + } + private static IEnumerable GroupJoinIterator(IEnumerable outer, IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func, TResult> resultSelector, IEqualityComparer? comparer) { using IEnumerator e = outer.GetEnumerator(); @@ -60,4 +109,22 @@ private static IEnumerable GroupJoinIterator : IGrouping + { + private readonly TKey _key; + private readonly IEnumerable _elements; + + public GroupJoinGrouping(TKey key, IEnumerable elements) + { + _key = key; + _elements = elements; + } + + public TKey Key => _key; + + public IEnumerator GetEnumerator() => _elements.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } } diff --git a/src/libraries/System.Linq/tests/GroupJoinTests.cs b/src/libraries/System.Linq/tests/GroupJoinTests.cs index 833ada49a5d89c..9144b718423333 100644 --- a/src/libraries/System.Linq/tests/GroupJoinTests.cs +++ b/src/libraries/System.Linq/tests/GroupJoinTests.cs @@ -513,5 +513,179 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator>; Assert.False(en is not null && en.MoveNext()); } + + [Fact] + public void GroupJoinWithoutResultSelector_OuterEmptyInnerNonEmpty() + { + CustomerRec[] outer = []; + OrderRec[] inner = + [ + new OrderRec{ orderID = 45321, custID = 98022, total = 50 }, + new OrderRec{ orderID = 97865, custID = 32103, total = 25 } + ]; + Assert.Empty(outer.GroupJoin(inner, e => e.custID, e => e.custID)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_OuterNonEmptyInnerEmpty() + { + CustomerRec[] outer = + [ + new CustomerRec{ name = "Tim", custID = 43434 }, + new CustomerRec{ name = "Bob", custID = 34093 } + ]; + OrderRec[] inner = []; + + var result = outer.GroupJoin(inner, e => e.custID, e => e.custID).ToList(); + + Assert.Equal(2, result.Count); + Assert.Equal(outer[0], result[0].Key); + Assert.Empty(result[0]); + Assert.Equal(outer[1], result[1].Key); + Assert.Empty(result[1]); + } + + [Fact] + public void GroupJoinWithoutResultSelector_SingleElementEachAndMatches() + { + CustomerRec[] outer = [new CustomerRec{ name = "Tim", custID = 43434 }]; + OrderRec[] inner = [new OrderRec{ orderID = 97865, custID = 43434, total = 25 }]; + + var result = outer.GroupJoin(inner, e => e.custID, e => e.custID).ToList(); + + Assert.Single(result); + Assert.Equal(outer[0], result[0].Key); + Assert.Single(result[0]); + Assert.Equal(inner[0], result[0].First()); + } + + [Fact] + public void GroupJoinWithoutResultSelector_InnerSameKeyMoreThanOneElementAndMatches() + { + CustomerRec[] outer = + [ + new CustomerRec{ name = "Tim", custID = 1234 }, + new CustomerRec{ name = "Bob", custID = 9865 } + ]; + OrderRec[] inner = + [ + new OrderRec{ orderID = 97865, custID = 1234, total = 25 }, + new OrderRec{ orderID = 34390, custID = 1234, total = 19 }, + new OrderRec{ orderID = 34390, custID = 9865, total = 19 } + ]; + + var result = outer.GroupJoin(inner, e => e.custID, e => e.custID).ToList(); + + Assert.Equal(2, result.Count); + Assert.Equal(outer[0], result[0].Key); + Assert.Equal(2, result[0].Count()); + Assert.Equal(outer[1], result[1].Key); + Assert.Single(result[1]); + } + + [Fact] + public void GroupJoinWithoutResultSelector_OuterNull() + { + CustomerRec[] outer = null; + OrderRec[] inner = + [ + new OrderRec{ orderID = 45321, custID = 98022, total = 50 } + ]; + + AssertExtensions.Throws("outer", () => outer.GroupJoin(inner, e => e.custID, e => e.custID)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_InnerNull() + { + CustomerRec[] outer = + [ + new CustomerRec{ name = "Tim", custID = 1234 } + ]; + OrderRec[] inner = null; + + AssertExtensions.Throws("inner", () => outer.GroupJoin(inner, e => e.custID, e => e.custID)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_OuterKeySelectorNull() + { + CustomerRec[] outer = + [ + new CustomerRec{ name = "Tim", custID = 1234 } + ]; + OrderRec[] inner = + [ + new OrderRec{ orderID = 45321, custID = 98022, total = 50 } + ]; + + AssertExtensions.Throws("outerKeySelector", () => outer.GroupJoin(inner, null, e => e.custID)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_InnerKeySelectorNull() + { + CustomerRec[] outer = + [ + new CustomerRec{ name = "Tim", custID = 1234 } + ]; + OrderRec[] inner = + [ + new OrderRec{ orderID = 45321, custID = 98022, total = 50 } + ]; + + AssertExtensions.Throws("innerKeySelector", () => outer.GroupJoin(inner, e => e.custID, null)); + } + + [Fact] + public void GroupJoinWithoutResultSelector_CanIterateMultipleTimes() + { + CustomerRec[] outer = + [ + new CustomerRec{ name = "Tim", custID = 1234 } + ]; + OrderRec[] inner = + [ + new OrderRec{ orderID = 97865, custID = 1234, total = 25 } + ]; + + var result = outer.GroupJoin(inner, e => e.custID, e => e.custID).ToList(); + + Assert.Single(result); + + // Iterate the grouped elements multiple times + Assert.Single(result[0]); + Assert.Single(result[0]); + Assert.Equal(inner[0], result[0].First()); + Assert.Equal(inner[0], result[0].First()); + } + + [Fact] + public void GroupJoinWithoutResultSelector_CustomComparer() + { + CustomerRec[] outer = + [ + new CustomerRec{ name = "Tim", custID = 1234 }, + new CustomerRec{ name = "Bob", custID = 9865 }, + new CustomerRec{ name = "Robert", custID = 9895 } + ]; + AnagramRec[] inner = + [ + new AnagramRec{ name = "Robert", orderID = 93483, total = 19 }, + new AnagramRec{ name = "miT", orderID = 93489, total = 45 } + ]; + + var result = outer.GroupJoin(inner, e => e.name, e => e.name, new AnagramEqualityComparer()).ToList(); + + Assert.Equal(3, result.Count); + Assert.Equal(outer[0], result[0].Key); + Assert.Single(result[0]); + Assert.Equal(inner[1], result[0].First()); + Assert.Equal(outer[1], result[1].Key); + Assert.Empty(result[1]); + Assert.Equal(outer[2], result[2].Key); + Assert.Single(result[2]); + Assert.Equal(inner[0], result[2].First()); + } } }