Skip to content

Commit a586839

Browse files
authored
425 fix progress bars on jupyter (#426)
* fix(events): Add sync handlers in event dispatcher - show_progress as sync to fix broken progress bars in jupyter environments - Introduce HandlerMetadata with sync flag in - registry Allow decorator usage with - @..._event_handler(sync=True|False) Execute sync - handlers on calling thread; queue only async - ones Submit only async handlers to thread pool; - adjust shutdown Add unregister_handler API and - update tests accordingly Mark progress - test verifying sync vs async execution threads * feat(events): Emit COMMAND_FINISHED and reset progress - Introduce EventType.COMMAND_FINISHED - Emit command.finished when engine signals exit Stop and recreate Progress on command finish - clear bars Reconstruct renderable only in textual output to avoid control chars - Force monochrome output in VSCode terminals - Rename new_task_with_reconstructed_renderable to new_task_and_reconstruct_renderable_maybe * Fix union type in handler; use sum generic in test_events
1 parent a96d001 commit a586839

File tree

6 files changed

+179
-41
lines changed

6 files changed

+179
-41
lines changed

src/python-api/getml/communication.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@
4343
PoolEventDispatcher,
4444
)
4545
from getml.events.types import (
46+
Event,
4647
EventContext,
4748
EventEmitter,
4849
EventParser,
4950
EventSource,
51+
EventType,
5052
)
5153
from getml.exceptions import handle_engine_exception
5254
from getml.helpers import _is_iterable_not_str
@@ -95,6 +97,12 @@ def listen(self, socket: socket.socket, exit_on: Callable[[str], bool]) -> str:
9597
self.emitter.emit(events)
9698

9799
if exit_on(exit_status := message):
100+
finished_event = Event(
101+
source=self.parser.context.source,
102+
type=EventType.COMMAND_FINISHED,
103+
attributes={},
104+
)
105+
self.emitter.emit([finished_event])
98106
return exit_status
99107

100108

src/python-api/getml/events/dispatchers.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import warnings
1717
from concurrent.futures import Executor, Future, ThreadPoolExecutor
1818
from types import TracebackType
19-
from typing import Callable, Dict, Optional, Sequence, Type, Union
19+
from typing import Dict, Optional, Sequence, Type, Union
2020

21-
from getml.events.handlers import EventHandlerRegistry
21+
from getml.events.handlers import EventHandlerRegistry, HandlerMetadata
2222
from getml.events.types import Event, EventSource, Shutdown
2323

2424

@@ -40,8 +40,8 @@ def __init__(
4040
):
4141
self.queues = {
4242
source: {
43-
handler: queue.Queue()
44-
for handler in EventHandlerRegistry.handlers[source]
43+
metadata: queue.Queue()
44+
for metadata in EventHandlerRegistry.handlers[source]
4545
}
4646
for source in EventSource
4747
}
@@ -50,7 +50,7 @@ def __init__(
5050
executor = ThreadPoolExecutor()
5151
self.executor = executor
5252

53-
self.futures: Dict[Callable[[Event], None], Future] = {}
53+
self.futures: Dict[HandlerMetadata, Future] = {}
5454

5555
def __enter__(self):
5656
self.start()
@@ -69,12 +69,23 @@ def __exit__(
6969

7070
def dispatch(self, events: Sequence[Event]):
7171
for event in events:
72-
for queue in self.queues[event.source].values():
73-
queue.put(event)
72+
for metadata, handler_queue in self.queues[event.source].items():
73+
if metadata.sync:
74+
# Execute synchronously in the calling thread
75+
try:
76+
metadata.handler(event) # type: ignore
77+
except Exception as e:
78+
warnings.warn(
79+
f"An exception occurred while dispatching event {event} to handler {metadata.handler}: {e}",
80+
RuntimeWarning,
81+
)
82+
else:
83+
# Dispatch to worker thread via queue
84+
handler_queue.put(event)
7485

7586
def process(
7687
self,
77-
handler: Callable[[Event], None],
88+
metadata: HandlerMetadata,
7889
handler_queue: queue.Queue[Union[Event, Shutdown]],
7990
):
8091
while True:
@@ -88,29 +99,33 @@ def process(
8899
return
89100

90101
try:
91-
handler(event) # type: ignore
102+
metadata.handler(event) # type: ignore
92103
except Exception as e:
93104
warnings.warn(
94-
f"An exception occurred while dispatching event {event} to handler {handler}: {e}",
105+
f"An exception occurred while dispatching event {event} to handler {metadata.handler}: {e}",
95106
RuntimeWarning,
96107
)
97108
finally:
98109
handler_queue.task_done()
99110

100111
def start(self):
101112
for source in self.queues:
102-
for handler, queue in self.queues[source].items():
103-
self.futures[handler] = self.executor.submit(
104-
self.process, handler, queue
105-
)
113+
for metadata, handler_queue in self.queues[source].items():
114+
if not metadata.sync:
115+
# Only submit async handlers to the thread pool
116+
self.futures[metadata] = self.executor.submit(
117+
self.process, metadata, handler_queue
118+
)
106119

107120
def stop(self, wait: bool = True):
108121
if not wait:
109122
self.executor.shutdown(wait=False)
110123
return
111124

112125
for source in self.queues:
113-
for queue in self.queues[source].values():
114-
queue.join()
115-
queue.put(Shutdown)
126+
for metadata, handler_queue in self.queues[source].items():
127+
if not metadata.sync:
128+
# Only wait for and shutdown async handlers
129+
handler_queue.join()
130+
handler_queue.put(Shutdown)
116131
self.executor.shutdown(wait=True)

src/python-api/getml/events/handlers.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,93 @@
1010
Event system.
1111
"""
1212

13-
from typing import Callable, Dict, List
13+
from dataclasses import dataclass
14+
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
1415

1516
from getml.events.types import Event, EventSource
1617

1718
EventHandler = Callable[[Event], None]
19+
H = TypeVar("H", bound=EventHandler)
20+
21+
22+
@dataclass(frozen=True)
23+
class HandlerMetadata:
24+
"""Metadata for an event handler."""
25+
26+
handler: EventHandler
27+
sync: bool = False
1828

1929

2030
class EventHandlerRegistry:
21-
handlers: Dict[EventSource, List[EventHandler]] = {}
31+
handlers: Dict[EventSource, List[HandlerMetadata]] = {}
2232

2333
def __init__(self, source: EventSource):
2434
self.source = source
2535
type(self).handlers[source] = []
2636

27-
def __call__(self, handler: EventHandler) -> EventHandler:
37+
@overload
38+
def __call__(self, handler: H) -> H: ...
39+
40+
@overload
41+
def __call__(self, *, sync: bool = False) -> Callable[[H], H]: ...
42+
43+
def __call__(
44+
self, handler: Optional[H] = None, *, sync: bool = False
45+
) -> Union[Callable[[H], H], H]:
2846
"""
2947
Register an event handler for the source associated with the instance.
48+
49+
Can be used as a decorator with or without arguments:
50+
- @event_handler
51+
- @event_handler(sync=True)
52+
53+
Args:
54+
handler: The event handler function
55+
sync: If True, handler executes synchronously in the calling thread
56+
(no dispatching). If False, handler executes in a worker thread pool.
57+
58+
Returns:
59+
The handler function (for decorator usage)
3060
"""
31-
self.register_handler(self.source, handler)
61+
if handler is None:
62+
# Called with arguments: @event_handler(sync=...)
63+
def decorator(func: H) -> H:
64+
self.register_handler(self.source, func, sync=sync)
65+
return func
66+
67+
return decorator
68+
69+
# Called without arguments: @event_handler
70+
self.register_handler(self.source, handler, sync=sync)
3271
return handler
3372

3473
@classmethod
35-
def register_handler(cls, source: EventSource, handler: EventHandler):
74+
def register_handler(
75+
cls, source: EventSource, handler: EventHandler, sync: bool = False
76+
):
3677
"""
3778
Register an event handler for a specific source.
79+
80+
Args:
81+
source: The event source
82+
handler: The event handler function
83+
sync: If True, handler executes synchronously in the calling thread
84+
"""
85+
metadata = HandlerMetadata(handler=handler, sync=sync)
86+
cls.handlers[source].append(metadata)
87+
88+
@classmethod
89+
def unregister_handler(cls, source: EventSource, handler: EventHandler):
90+
"""
91+
Unregister an event handler for a specific source.
92+
93+
Args:
94+
source: The event source
95+
handler: The event handler function to remove
3896
"""
39-
cls.handlers[source].append(handler)
97+
cls.handlers[source] = [
98+
metadata for metadata in cls.handlers[source] if metadata.handler != handler
99+
]
40100

41101

42102
engine_event_handler = EventHandlerRegistry(EventSource.ENGINE)

src/python-api/getml/events/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class EventTypeState(str, Enum):
3030

3131

3232
class EventType(str, Enum):
33+
COMMAND_FINISHED = "command.finished"
3334
PIPELINE_CHECK_STAGE_START = "pipeline.check.stage.start"
3435
PIPELINE_CHECK_STAGE_PROGRESS = "pipeline.check.stage.progress"
3536
PIPELINE_CHECK_STAGE_END = "pipeline.check.stage.end"

src/python-api/getml/utilities/progress.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from rich.style import Style
3434

3535
from getml.events.handlers import engine_event_handler, monitor_event_handler
36-
from getml.events.types import Event, EventTypeState
36+
from getml.events.types import Event, EventType, EventTypeState
3737

3838

3939
def _is_emacs_kernel() -> bool:
@@ -89,6 +89,8 @@ def _should_enforce_textual_output() -> bool:
8989
def _should_enforce_monochrome_output() -> bool:
9090
if env_value := os.getenv("GETML_PROGRESS_FORCE_MONOCHROME_OUTPUT"):
9191
return env_value in _ENV_VAR_TRUTHY_VALUES
92+
if _is_vscode():
93+
return True
9294
return False
9395

9496

@@ -314,7 +316,7 @@ def new_task(
314316
)
315317
return task_id
316318

317-
def new_task_with_reconstructed_renderable(
319+
def new_task_and_reconstruct_renderable_maybe(
318320
self,
319321
description: str,
320322
*,
@@ -327,7 +329,9 @@ def new_task_with_reconstructed_renderable(
327329
renderable to ensure that the progress bar is displayed correctly
328330
with interleaved `print` statements.
329331
"""
330-
self._reconstruct_renderable()
332+
if _should_enforce_textual_output():
333+
self._reconstruct_renderable()
334+
331335
task_id = self.new_task(
332336
description, total=total, completed=completed, visible=visible
333337
)
@@ -426,17 +430,22 @@ def __init__(self, progress: Progress):
426430
@contextmanager
427431
def live_renderable(self, event: Event) -> Iterator[Live]:
428432
yield self.progress.live
429-
if event.type.state is EventTypeState.END:
433+
if event.type is EventType.COMMAND_FINISHED:
430434
self.progress.stop()
435+
self.progress = Progress()
431436

432437
def create_or_update_progress(self, event: Event):
433438
"""
434439
Create a new progress task or update an existing one.
440+
441+
A new Progress is created lazily if all previous tasks have finished. Progresses
442+
cannot be created eagerly, because construction a live renderable in a
443+
textual environment results in control characters being printed to the output.
435444
"""
436445

437446
with self.live_renderable(event):
438447
if event.type.state is EventTypeState.START:
439-
task_id = self.progress.new_task_with_reconstructed_renderable(
448+
task_id = self.progress.new_task_and_reconstruct_renderable_maybe(
440449
event.attributes["description"]
441450
)
442451
elif event.type.state is EventTypeState.PROGRESS and self.progress.tasks:
@@ -451,8 +460,8 @@ def create_or_update_progress(self, event: Event):
451460
progress_event_handler = ProgressEventHandler(progress=Progress())
452461

453462

454-
@monitor_event_handler
455-
@engine_event_handler
463+
@monitor_event_handler(sync=True)
464+
@engine_event_handler(sync=True)
456465
def show_progress(event: Event):
457466
if DISABLE:
458467
return

0 commit comments

Comments
 (0)