Skip to content

Commit

Permalink
[proxima-beam-core] #339 expander fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
je-ik committed Oct 24, 2024
1 parent d4656b6 commit a1e4b93
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,9 @@ PTransform<PCollection<InputT>, PCollectionTuple> transformedParDo(
return new PTransform<>() {
@Override
public PCollectionTuple expand(PCollection<InputT> input) {
Coder<InputT> inputCoder = input.getCoder();
@SuppressWarnings("unchecked")
KvCoder<K, V> coder = (KvCoder<K, V>) input.getCoder();
KvCoder<K, V> coder = (KvCoder<K, V>) inputCoder;
Coder<K> keyCoder = coder.getKeyCoder();
Coder<V> valueCoder = coder.getValueCoder();
TypeDescriptor<StateOrInput<V>> valueDescriptor =
Expand Down Expand Up @@ -413,7 +414,7 @@ public PCollectionTuple expand(PCollection<InputT> input) {
PCollectionTuple tuple =
flattened.apply(
"expanded",
ParDo.of(transformedDoFn(doFn, (KvCoder<K, V>) input.getCoder(), mainOutputTag))
ParDo.of(transformedDoFn(doFn, (KvCoder<K, V>) inputCoder, mainOutputTag))
.withOutputTags(mainOutputTag, otherOutputs.and(STATE_TUPLE_TAG)));
PCollectionTuple res = PCollectionTuple.empty(input.getPipeline());
for (Entry<TupleTag<Object>, PCollection<Object>> e :
Expand Down Expand Up @@ -946,7 +947,7 @@ public void intercept(@This DoFn<KV<V, StateOrInput<V>>, ?> doFn, @AllArguments
boolean isNextScheduled =
nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE);
if (isNextScheduled) {
flushTimer.set(nextFlush);
flushTimer.withOutputTimestamp(nextFlush).set(nextFlush);
nextFlushState.write(nextFlush);
}
@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,16 +331,18 @@ static Map<String, BiConsumer<Object, StateValue>> getStateUpdaters(DoFn<?, ?> d
Pair.of(
p.getSecond().value(),
createUpdater(
p.getSecond().value(),
((StateSpec<?>)
ExceptionUtils.uncheckedFactory(() -> p.getFirst().get(doFn))))))
.filter(p -> p.getSecond() != null)
.collect(Collectors.toMap(Pair::getFirst, Pair::getSecond));
}

@SuppressWarnings("unchecked")
private static @Nullable BiConsumer<Object, StateValue> createUpdater(StateSpec<?> stateSpec) {
private static @Nullable BiConsumer<Object, StateValue> createUpdater(
String name, StateSpec<?> stateSpec) {

AtomicReference<BiConsumer<Object, StateValue>> consumer = new AtomicReference<>();
stateSpec.bind("dummy", createUpdaterBinder(consumer));
stateSpec.bind(name, createUpdaterBinder(consumer));
return consumer.get();
}

Expand All @@ -363,6 +365,7 @@ static LinkedHashMap<String, BiFunction<Object, byte[], Iterable<StateValue>>> g
Pair.of(
p.getSecond().value(),
createReader(
p.getSecond().value(),
((StateSpec<?>)
ExceptionUtils.uncheckedFactory(() -> p.getFirst().get(doFn))))))
.filter(p -> p.getSecond() != null)
Expand All @@ -372,9 +375,10 @@ static LinkedHashMap<String, BiFunction<Object, byte[], Iterable<StateValue>>> g

@SuppressWarnings("unchecked")
private static @Nullable BiFunction<Object, byte[], Iterable<StateValue>> createReader(
StateSpec<?> stateSpec) {
String name, StateSpec<?> stateSpec) {

AtomicReference<BiFunction<Object, byte[], Iterable<StateValue>>> res = new AtomicReference<>();
stateSpec.bind("dummy", createStateReaderBinder(res));
stateSpec.bind(name, createStateReaderBinder(res));
return res.get();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ public void testSimpleExpandWithStateStore() throws IOException {
(int)
CoderUtils.decodeFromByteArray(
VarIntCoder.of(), second.getValue().getValue().getKey()));
assertEquals("sum", second.getValue().getValue().getName());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ public void testTransactionCreateCommit() throws InterruptedException {
}

@Test(timeout = 10_000)
public void testTransactionCreateUpdateCommitMultipleOutputs()
throws InterruptedException, TransactionRejectedException {

public void testTransactionCreateUpdateCommitMultipleOutputs() throws InterruptedException {
CachedView view = Optionals.get(direct.getCachedView(status));
view.assign(view.getPartitions());
OnlineAttributeWriter writer = Optionals.get(direct.getWriter(status));
Expand Down Expand Up @@ -348,9 +346,7 @@ public void testTransactionCommitReject() throws InterruptedException {
}

@Test(timeout = 10000)
public void testGlobalTransactionWriter()
throws InterruptedException, TransactionRejectedException {

public void testGlobalTransactionWriter() throws InterruptedException {
TransactionalOnlineAttributeWriter writer = direct.getGlobalTransactionWriter();
assertTrue(user.isTransactional());
// we successfully open and commit the transaction
Expand Down

0 comments on commit a1e4b93

Please sign in to comment.