Skip to content

Commit

Permalink
initial draft of synchronous pipeline runner
Browse files Browse the repository at this point in the history
  • Loading branch information
sneakers-the-rat committed Jan 24, 2025
1 parent 643a7f0 commit 209aaed
Show file tree
Hide file tree
Showing 6 changed files with 373 additions and 96 deletions.
6 changes: 3 additions & 3 deletions mio/data/config/wirefree/wirefree-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ nodes:
path: sd_path
outputs:
- source: header
target: merge
target: merge.header
- source: buffer
target: merge
target: merge.buffer
merge:
type: "merge-buffers"
fill:
width: file.width
height: file.height
outputs:
- source: frame
target: data
target: return
return:
config:
key: frame
Expand Down
142 changes: 111 additions & 31 deletions mio/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import sys
from abc import abstractmethod
from datetime import datetime
from graphlib import TopologicalSorter
from typing import Any, ClassVar, Final, Generic, Optional, TypedDict, TypeVar, Union, final
from typing import Any, ClassVar, Generic, Optional, TypedDict, TypeVar, Union, Unpack, final

from pydantic import Field, field_validator, model_validator

Expand All @@ -21,12 +22,29 @@
"""
Input Type typevar
"""
U = TypeVar("U")
U = TypeVar("U", bound=dict[str, Any])
"""
Output Type typevar
"""


class Event(TypedDict, Generic[U]):
"""
Container for a single value returned from a single :meth:`.Node.process` call
"""

id: int
"""Unique ID for each event"""
timestamp: datetime
"""Timestamp of when the event was received by the :class:`.PipelineRunner"""
node_id: str
"""ID of node that emitted the event"""
slot: str
"""name of the slot that emitted the event"""
value: Any
"""Value emitted by the processing node"""


class _NodeMap(TypedDict):
source: str
target: str
Expand Down Expand Up @@ -231,9 +249,7 @@ class Node(PipelineModel, Generic[T, U]):
config: Optional[NodeConfig] = None

input_type: ClassVar[type[T]]
inputs: dict[str, Union["Source", "Transform"]] = Field(default_factory=dict)
output_type: ClassVar[type[U]]
outputs: dict[str, Union["Sink", "Transform"]] = Field(default_factory=dict)

def start(self) -> None:
"""
Expand All @@ -253,6 +269,11 @@ def stop(self) -> None:
"""
pass

@abstractmethod
def process(self, **kwargs: Unpack[T]) -> Optional[U]:
"""Process some input, emitting it. See subclasses for details"""
pass

@classmethod
def from_specification(cls, config: NodeSpecification) -> Self:
"""
Expand Down Expand Up @@ -283,7 +304,6 @@ def node_types(cls) -> dict[str, type["Node"]]:
class Source(Node, Generic[T, U]):
"""A source of data in a processing pipeline"""

inputs: Final[None] = None
input_type: ClassVar[None] = None

@abstractmethod
Expand All @@ -305,10 +325,9 @@ class Sink(Node, Generic[T, U]):
"""A sink of data in a processing pipeline"""

output_type: ClassVar[None] = None
outputs: Final[None] = None

@abstractmethod
def process(self, data: T) -> None:
def process(self, **kwargs: Unpack[T]) -> None:
"""
Process some incoming data, returning None
Expand All @@ -326,7 +345,7 @@ class Transform(Node, Generic[T, U]):
"""

@abstractmethod
def process(self, data: T) -> U:
def process(self, **kwargs: Unpack[T]) -> U:
"""
Process some incoming data, yielding a transformed output
Expand All @@ -339,6 +358,17 @@ def process(self, data: T) -> U:
"""


class Edge(PipelineModel):
"""
Directed connection between an output slot a node and an input slot in another node
"""

source_node: Node
source_slot: Optional[str] = None
target_node: Node
target_slot: Optional[str] = None


class Pipeline(PipelineModel):
"""
A graph of nodes transforming some input source(s) to some output sink(s)
Expand All @@ -352,6 +382,14 @@ class Pipeline(PipelineModel):
"""
Dictionary mapping all nodes from their ID to the instantiated node.
"""
edges: list[Edge] = Field(default_factory=list)
"""
Edges connecting slots within nodes.
The nodes within :attr:`.Edge.source_node` and :attr:`.Edge.target_node` must
be the same objects as those in :attr:`.Pipeline.nodes`
(i.e. ``edges[0].source_node is nodes[node_id]`` ).
"""

@property
def sources(self) -> dict[str, "Source"]:
Expand All @@ -368,6 +406,45 @@ def sinks(self) -> dict[str, "Sink"]:
"""All :class:`.Sink` nodes in the processing graph"""
return {k: v for k, v in self.nodes.items() if isinstance(v, Sink)}

def graph(self) -> TopologicalSorter:
"""
Produce a :class:`.TopologicalSorter` based on the graph induced by
:attr:`.Pipeline.nodes` and :attr:`.Pipeline.edges` that yields node ids
"""
sorter = TopologicalSorter()
for node_id, node in self.nodes.items():
in_edges = [e.target_node.id for e in self.edges if e.target_node is node]
sorter.add(node_id, *in_edges)
return sorter

def in_edges(self, node: Union[Node, str]) -> list[Edge]:
"""
Edges going towards the given node (i.e. the node is the edge's ``target`` )
Args:
node (:class:`.Node`, str): Either a node or its id
Returns:
list[:class:`.Edge`]
"""
if isinstance(node, Node):
node = node.id
return [e for e in self.edges if e.target_node.id == node]

def out_edges(self, node: Union[Node, str]) -> list[Edge]:
"""
Edges going away from the given node (i.e. the node is the edge's ``source`` )
Args:
node (:class:`.Node`, str): Either a node or its id
Returns:
list[:class:`.Edge`]
"""
if isinstance(node, Node):
node = node.id
return [e for e in self.edges if e.source_node.id == node]

@classmethod
def from_config(cls, config: PipelineConfig, passed: Optional[dict[str, Any]] = None) -> Self:
"""
Expand All @@ -380,8 +457,9 @@ def from_config(cls, config: PipelineConfig, passed: Optional[dict[str, Any]] =
cls._validate_passed(config, passed)

nodes = cls._init_nodes(config, passed)
edges = cls._init_edges(nodes, config.nodes)

return cls(nodes=nodes)
return cls(nodes=nodes, edges=edges)

@classmethod
def passed_values(cls, config: PipelineConfig) -> dict[str, type]:
Expand Down Expand Up @@ -435,6 +513,30 @@ def _init_nodes(

return nodes

@classmethod
def _init_edges(cls, nodes: dict[str, Node], spec: dict[str, NodeSpecification]) -> list[Edge]:
edges = []
for node_id, node_spec in spec.items():
if not node_spec.outputs:
continue
for output in node_spec.outputs:
# FIXME: Ugly and not DRY
target_parts = output["target"].split(".")
target_id, target_slot = (
(target_parts[0], target_parts[1])
if len(target_parts) == 2
else (target_parts[0], None)
)
edges.append(
Edge(
source_node=nodes[node_id],
target_node=nodes[target_id],
source_slot=output["source"],
target_slot=target_slot,
)
)
return edges

@classmethod
def _complete_node(
cls, node: NodeSpecification, context: dict[str, Node], passed: dict
Expand Down Expand Up @@ -471,25 +573,3 @@ def _validate_passed(cls, config: PipelineConfig, passed: dict[str, Any]) -> Non
f"But received passed values:\n"
f"{passed}"
)


def connect_nodes(nodes: dict[str, Node]) -> dict[str, Node]:
"""
Provide references to instantiated nodes
"""

for node in nodes.values():
if node.config.inputs and node.inputs is None:
raise ConfigurationMismatchError(
"inputs found in node configuration, but node type allows no inputs!\n"
f"node: {node.model_dump()}"
)
if node.config.outputs and not hasattr(node, "outputs"):
raise ConfigurationMismatchError(
"outputs found in node configuration, but node type allows no outputs!\n"
f"node: {node.model_dump()}"
)

node.inputs.update({id: nodes[id] for id in node.config.inputs})
node.outputs.update({id: nodes[id] for id in node.config.outputs})
return nodes
Loading

0 comments on commit 209aaed

Please sign in to comment.