|
71 | 71 | build_flow,
|
72 | 72 | )
|
73 | 73 | from storey.flow import (
|
| 74 | + ConcurrentExecution, |
74 | 75 | Context,
|
75 | 76 | ParallelExecution,
|
76 | 77 | ParallelExecutionRunnable,
|
@@ -360,6 +361,50 @@ def test_async_offset_commit_before_termination_with_nosqltarget():
|
360 | 361 | asyncio.run(async_offset_commit_before_termination_with_nosqltarget())
|
361 | 362 |
|
362 | 363 |
|
| 364 | +async def async_offset_commit_before_termination_with_concurrent_execution(): |
| 365 | + platform = Committer() |
| 366 | + context = CommitterContext(platform) |
| 367 | + |
| 368 | + max_wait_before_commit = 1 |
| 369 | + |
| 370 | + controller = build_flow( |
| 371 | + [ |
| 372 | + AsyncEmitSource(context=context, explicit_ack=True, max_wait_before_commit=max_wait_before_commit), |
| 373 | + ConcurrentExecution(event_processor=lambda x: x + 1), |
| 374 | + Filter(lambda x: x < 3), |
| 375 | + FlatMap(lambda x: [x, x * 10]), |
| 376 | + Reduce(0, lambda acc, x: acc + x), |
| 377 | + ] |
| 378 | + ).run() |
| 379 | + |
| 380 | + num_shards = 10 |
| 381 | + num_records_per_shard = 10 |
| 382 | + |
| 383 | + for offset in range(1, num_records_per_shard + 1): |
| 384 | + for shard in range(num_shards): |
| 385 | + event = Event(shard) |
| 386 | + event.shard_id = shard |
| 387 | + event.offset = offset |
| 388 | + await controller.emit(event) |
| 389 | + |
| 390 | + del event |
| 391 | + |
| 392 | + await asyncio.sleep(max_wait_before_commit + 1) |
| 393 | + |
| 394 | + try: |
| 395 | + offsets = copy.copy(platform.offsets) |
| 396 | + assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)} |
| 397 | + finally: |
| 398 | + await controller.terminate() |
| 399 | + termination_result = await controller.await_termination() |
| 400 | + assert termination_result == 330 |
| 401 | + |
| 402 | + |
| 403 | +# ML-8799 |
| 404 | +def test_async_offset_commit_before_termination_with_concurrent_execution(): |
| 405 | + asyncio.run(async_offset_commit_before_termination_with_concurrent_execution()) |
| 406 | + |
| 407 | + |
363 | 408 | def test_offset_not_committed_prematurely():
|
364 | 409 | platform = Committer()
|
365 | 410 | context = CommitterContext(platform)
|
|
0 commit comments