Skip to content

Commit d1b8dcf

Browse files
authored
Raise error in Choice on duplicate outlets (#545)
* Raise error in `Choice` on duplicate outlets * Move validation * Add tests, improve error messages
1 parent 22765a0 commit d1b8dcf

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

storey/flow.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from asyncio import Task
2424
from collections import defaultdict
2525
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
26-
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
26+
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Union
2727

2828
import aiohttp
2929

@@ -363,12 +363,12 @@ def _init(self):
363363
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
364364
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"]
365365

366-
def select_outlets(self, event) -> List[str]:
366+
def select_outlets(self, event) -> Collection[str]:
367367
"""
368368
Override this method to route events based on a customer logic. The default implementation will route all
369369
events to all outlets.
370370
"""
371-
return list(self._name_to_outlet.keys())
371+
return self._name_to_outlet.keys()
372372

373373
async def _do(self, event):
374374
if event is _termination_obj:
@@ -381,6 +381,11 @@ async def _do(self, event):
381381
outlet = self._name_to_outlet["dataframe"]
382382
outlets.append(outlet)
383383
else:
384+
if len(set(outlet_names)) != len(outlet_names):
385+
raise ValueError(
386+
"select_outlets() returned duplicate outlets among the defined outlets: "
387+
+ ", ".join(outlet_names)
388+
)
384389
for outlet_name in outlet_names:
385390
if outlet_name not in self._name_to_outlet:
386391
raise ValueError(

tests/test_flow.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,6 +1728,50 @@ def select_outlets(self, event):
17281728
assert termination_result == expected
17291729

17301730

1731+
def test_duplicate_choice():
1732+
class DuplicateChoice(Choice):
1733+
def select_outlets(self, event):
1734+
outlets = ["all_events", "all_events"]
1735+
return outlets
1736+
1737+
source = SyncEmitSource()
1738+
duplicate_choice = DuplicateChoice(termination_result_fn=lambda x, y: x + y)
1739+
all_events = Map(lambda x: x, name="all_events")
1740+
1741+
source.to(duplicate_choice).to(all_events)
1742+
1743+
controller = source.run()
1744+
controller.emit(0)
1745+
controller.terminate()
1746+
with pytest.raises(
1747+
ValueError,
1748+
match=r"select_outlets\(\) returned duplicate outlets among the defined outlets: all_events, all_events",
1749+
):
1750+
controller.await_termination()
1751+
1752+
1753+
def test_nonexistent_choice():
1754+
class NonexistentChoice(Choice):
1755+
def select_outlets(self, event):
1756+
outlets = ["wrong"]
1757+
return outlets
1758+
1759+
source = SyncEmitSource()
1760+
nonexistent_choice = NonexistentChoice(termination_result_fn=lambda x, y: x + y)
1761+
all_events = Map(lambda x: x, name="all_events")
1762+
1763+
source.to(nonexistent_choice).to(all_events)
1764+
1765+
controller = source.run()
1766+
controller.emit(0)
1767+
controller.terminate()
1768+
with pytest.raises(
1769+
ValueError,
1770+
match=r"select_outlets\(\) returned outlet name 'wrong', which is not one of the defined outlets: all_events",
1771+
):
1772+
controller.await_termination()
1773+
1774+
17311775
def test_metadata():
17321776
def mapf(x):
17331777
x.key = x.key + 1

0 commit comments

Comments
 (0)