From 5781be24ca32afa8e471070f1dde93e99437b2e7 Mon Sep 17 00:00:00 2001 From: Bradley Augstein Date: Thu, 9 Jan 2025 08:37:53 +0000 Subject: [PATCH] combine widget and script loggers --- heracles/notebook.py | 133 ++++++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 70 deletions(-) diff --git a/heracles/notebook.py b/heracles/notebook.py index 7576c50..7978a15 100644 --- a/heracles/notebook.py +++ b/heracles/notebook.py @@ -29,102 +29,95 @@ from IPython.display import display import sys -from typing import List +from typing import List, Union -class Progress: - """ - Progress bar using ipywidgets. - """ - - def __init__(self, label: str, *, box: widgets.Box | None = None) -> None: - if box is None: - self.box = widgets.VBox() - else: - self.box = box - self.widget = widgets.IntProgress( - value=0, - min=0, - max=1, - description=label, - orientation="horizontal", - ) - - def __enter__(self) -> "Progress": - if not self.box.children: - display(self.box) - self.box.children += (self.widget,) - return self - - def __exit__(self, *exc) -> None: - self.widget.close() - try: - index = self.box.children.index(self.widget) - except ValueError: - pass - else: - self.box.children = ( - self.box.children[:index] + self.box.children[index + 1 :] - ) - if not self.box.children: - self.box.close() +def is_notebook() -> bool: + try: + from IPython import get_ipython + if "IPKernelApp" in get_ipython().config: + return True + except Exception: + return False + return False - def update(self, current: int | None = None, total: int | None = None) -> None: - if current is not None: - self.widget.value = current - if total is not None: - self.widget.max = total - - def task(self, label: str) -> "Progress": - return self.__class__(label, box=self.box) - -class ProgressLogging: +class Progress: """ - Progress bar without GUI interface. + Combined progress bar that can use ipywidgets or a text-based progress bar. """ - def __init__(self, label: str, *, box: List["ProgressLogging"] = None) -> None: + def __init__(self, label: str, *, use_widgets: bool = True, box: Union[widgets.Box, List["Progress"]] = None) -> None: + self.use_widgets = use_widgets self.label = label self.current = 0 self.total = 1 # Default to 1 to avoid division by zero - self.box = box if box is not None else [] - self.line_offset = len(self.box) # Track which line to overwrite - sys.stdout.write("\n") - def __enter__(self) -> "ProgressLogging": - # Add this instance to the box if it's not already there - if self not in self.box: - self.box.append(self) - self._display_box() - #Ensure the cursor ends at a new line after the progress bars + if self.use_widgets: + self.box = box if box is not None else widgets.VBox() + self.widget = widgets.IntProgress( + value=0, + min=0, + max=1, + description=label, + orientation="horizontal", + ) + else: + if is_notebook(): + raise Exception("use_widgets=False - Cannot use Progress without widgets in notebook.") + self.box = box if box is not None else [] + self.line_offset = len(self.box) # Track which line to overwrite + sys.stdout.write("\n") + + def __enter__(self) -> "Progress": + if self.use_widgets: + if not self.box.children: + display(self.box) + self.box.children += (self.widget,) + else: + if self not in self.box: + self.box.append(self) + self._display_terminal() return self def __exit__(self, exc_type, exc_value, traceback) -> None: - self._display_box() + if self.use_widgets: + self.widget.close() + try: + index = self.box.children.index(self.widget) + except ValueError: + pass + else: + self.box.children = ( + self.box.children[:index] + self.box.children[index + 1 :] + ) + if not self.box.children: + self.box.close() + else: + self._display_terminal() def update(self, current: int | None = None, total: int | None = None) -> None: - # Update progress values if current is not None: self.current = current if total is not None: self.total = total - # Refresh the entire display box - self._display_box() - def task(self, label: str) -> "ProgressLogging": - # Create a new task tied to the same box - return self.__class__(label, box=self.box) + if self.use_widgets: + self.widget.value = self.current + self.widget.max = self.total + else: + self._display_terminal() + + def task(self, label: str) -> "Progress": + return self.__class__(label, use_widgets=self.use_widgets, box=self.box) - def _display_box(self) -> None: + def _display_terminal(self) -> None: """ Redraw the progress bars in the terminal for all tasks in the box. """ - # Move the cursor up to overwrite only the progress bar lines - sys.stdout.write(f"\033[{len(self.box)}F") # Move up N lines + sys.stdout.write(f"\033[{len(self.box)}F") # Move curser up N lines to redraw sys.stdout.flush() - # Display all progress bars for task in self.box: percentage = (task.current / task.total) * 100 - bar_length = 40 # Fixed number of blocks in the bar + bar_length = 40 progress_blocks = int(percentage // (100 / bar_length)) bar = "=" * progress_blocks + " " * (bar_length - progress_blocks) sys.stdout.write(f"\r{task.label}: [{bar}]\n")