diff --git a/src/main/java/com/pivovarit/gatherers/MoreGatherers.java b/src/main/java/com/pivovarit/gatherers/MoreGatherers.java index 342c6ab..a4bc934 100644 --- a/src/main/java/com/pivovarit/gatherers/MoreGatherers.java +++ b/src/main/java/com/pivovarit/gatherers/MoreGatherers.java @@ -19,6 +19,23 @@ public final class MoreGatherers { private MoreGatherers() { } + public static Gatherer sample(int n) { + if (n <= 0) { + throw new IllegalArgumentException("sample size can't be lower than 1"); + } + return n == 1 + ? noop() + : Gatherer.ofSequential( + () -> new AtomicLong(), + (state, element, downstream) -> { + if (state.getAndIncrement() % n == 0) { + downstream.push(element); + } + return true; + } + ); + } + public static Gatherer distinctBy(Function keyExtractor) { Objects.requireNonNull(keyExtractor); return Gatherer.ofSequential( @@ -44,7 +61,7 @@ private MoreGatherers() { return zip(other.iterator()); } - public static Gatherer zip(Collection other, BiFunction mapper) { + public static Gatherer zip(Collection other, BiFunction mapper) { return zip(other.iterator(), mapper); } @@ -75,4 +92,14 @@ private MoreGatherers() { }) ); } + + static Gatherer noop() { + return ofSequential( + () -> null, + (_, element, downstream) -> { + downstream.push(element); + return true; + } + ); + } } diff --git a/src/test/java/com/pivovarit/gatherers/SampleTest.java b/src/test/java/com/pivovarit/gatherers/SampleTest.java new file mode 100644 index 0000000..43cff54 --- /dev/null +++ b/src/test/java/com/pivovarit/gatherers/SampleTest.java @@ -0,0 +1,34 @@ +package com.pivovarit.gatherers; + +import org.junit.jupiter.api.Test; + +import java.util.stream.Stream; + +import static com.pivovarit.gatherers.MoreGatherers.sample; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class SampleTest { + + @Test + void shouldRejectInvalidSampleSize() { + assertThatThrownBy(() -> sample(0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("sample size can't be lower than 1"); + } + + @Test + void shouldSampleEmpty() throws Exception { + assertThat(Stream.empty().gather(sample(42))).isEmpty(); + } + + @Test + void shouldSampleEvery() { + assertThat(Stream.of(1,2,3).gather(sample(1))).containsExactly(1,2,3); + } + + @Test + void shouldSampleEveryOther() { + assertThat(Stream.of(1,2,3).gather(sample(2))).containsExactly(1,3); + } +}