Skip to content

Commit

Permalink
Add MoreGatherers.sample(int n)
Browse files Browse the repository at this point in the history
  • Loading branch information
pivovarit committed Oct 9, 2024
1 parent b1fb0b4 commit 2d94907
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/main/java/com/pivovarit/gatherers/MoreGatherers.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ public final class MoreGatherers {
private MoreGatherers() {
}

public static <T> Gatherer<T, ?, T> sample(int n) {
if (n <= 0) {
throw new IllegalArgumentException("sample size can't be lower than 1");
}
return n == 1
? noop()
: Gatherer.ofSequential(
() -> new AtomicLong(),
(state, element, downstream) -> {
if (state.getAndIncrement() % n == 0) {
downstream.push(element);
}
return true;
}
);
}

public static <T, U> Gatherer<T, ?, T> distinctBy(Function<? super T, ? extends U> keyExtractor) {
Objects.requireNonNull(keyExtractor);
return Gatherer.ofSequential(
Expand All @@ -44,7 +61,7 @@ private MoreGatherers() {
return zip(other.iterator());
}

public static <T1, T2, R> Gatherer<T1, ?, R> zip(Collection<T2> other, BiFunction<? super T1, ? super T2, ? extends R> mapper) {
public static <T1, T2, R> Gatherer<T1, ?, R> zip(Collection<T2> other, BiFunction<? super T1, ? super T2, ? extends R> mapper) {
return zip(other.iterator(), mapper);
}

Expand Down Expand Up @@ -75,4 +92,14 @@ private MoreGatherers() {
})
);
}

static <T> Gatherer<T, ?, T> noop() {
return ofSequential(
() -> null,
(_, element, downstream) -> {
downstream.push(element);
return true;
}
);
}
}
34 changes: 34 additions & 0 deletions src/test/java/com/pivovarit/gatherers/SampleTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.pivovarit.gatherers;

import org.junit.jupiter.api.Test;

import java.util.stream.Stream;

import static com.pivovarit.gatherers.MoreGatherers.sample;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class SampleTest {

@Test
void shouldRejectInvalidSampleSize() {
assertThatThrownBy(() -> sample(0))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("sample size can't be lower than 1");
}

@Test
void shouldSampleEmpty() throws Exception {
assertThat(Stream.empty().gather(sample(42))).isEmpty();
}

@Test
void shouldSampleEvery() {
assertThat(Stream.of(1,2,3).gather(sample(1))).containsExactly(1,2,3);
}

@Test
void shouldSampleEveryOther() {
assertThat(Stream.of(1,2,3).gather(sample(2))).containsExactly(1,3);
}
}

0 comments on commit 2d94907

Please sign in to comment.