Skip to content

Commit 4fc0bc2

Browse files
authored
Improve function name retrieval (#329)
1 parent 84c2f62 commit 4fc0bc2

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/spdl/pipeline/_builder.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
AsyncIterable,
1616
AsyncIterator,
1717
Awaitable,
18+
Callable,
1819
Coroutine,
1920
Iterable,
2021
Sequence,
@@ -37,7 +38,7 @@
3738
from ._pipeline import Pipeline
3839
from ._utils import create_task
3940

40-
__all__ = ["PipelineFailure", "PipelineBuilder"]
41+
__all__ = ["PipelineFailure", "PipelineBuilder", "_get_op_name"]
4142

4243
_LG = logging.getLogger(__name__)
4344

@@ -115,6 +116,12 @@ async def _put_eof_when_done(queue):
115116
################################################################################
116117

117118

119+
def _get_op_name(op: Callable) -> str:
120+
if isinstance(op, partial):
121+
return _get_op_name(op.func)
122+
return getattr(op, "__name__", op.__class__.__name__)
123+
124+
118125
def _pipe(
119126
input_queue: AsyncQueue[T],
120127
op: Callables[T, U],
@@ -619,7 +626,7 @@ def pipe(
619626
"when `output_order` is 'input'."
620627
)
621628

622-
name = name or getattr(op, "__name__", op.__class__.__name__)
629+
name = name or _get_op_name(op)
623630

624631
if kwargs:
625632
# pyre-ignore

0 commit comments

Comments
 (0)