diff --git a/changes/1745.bugfix.md b/changes/1745.bugfix.md new file mode 100644 index 0000000000..36a7ee3dd1 --- /dev/null +++ b/changes/1745.bugfix.md @@ -0,0 +1,3 @@ +Improve handing of force exiting a bot (double interrupt) +- Improve exception message +- Reset signal handlers to original ones after no longer capturing signals diff --git a/hikari/errors.py b/hikari/errors.py index 1bdacaf732..aa33d3d472 100644 --- a/hikari/errors.py +++ b/hikari/errors.py @@ -102,6 +102,9 @@ class HikariInterrupt(KeyboardInterrupt, HikariError): signame: str = attrs.field() """The signal name that was raised.""" + def __str__(self) -> str: + return f"Signal {self.signum} ({self.signame}) received" + @attrs.define(auto_exc=True, repr=False, slots=False) class ComponentStateConflictError(HikariError): diff --git a/hikari/impl/gateway_bot.py b/hikari/impl/gateway_bot.py index debe57bb42..6ff48a1cac 100644 --- a/hikari/impl/gateway_bot.py +++ b/hikari/impl/gateway_bot.py @@ -805,10 +805,10 @@ def run( except AttributeError: _LOGGER.log(ux.TRACE, "cannot set coroutine tracking depth for sys, no functionality exists for this") - try: - with signals.handle_interrupts( - enabled=enable_signal_handlers, loop=loop, propagate_interrupts=propagate_interrupts - ): + with signals.handle_interrupts( + enabled=enable_signal_handlers, loop=loop, propagate_interrupts=propagate_interrupts + ): + try: loop.run_until_complete( self.start( activity=activity, @@ -825,22 +825,27 @@ def run( loop.run_until_complete(self.join()) - finally: - if self._closing_event: - if self._closing_event.is_set(): - loop.run_until_complete(self._closing_event.wait()) - else: - loop.run_until_complete(self.close()) + finally: + try: + if self._closing_event: + if self._closing_event.is_set(): + loop.run_until_complete(self._closing_event.wait()) + else: + loop.run_until_complete(self.close()) + + if close_passed_executor and self._executor is not None: + _LOGGER.debug("shutting down executor %s", self._executor) + self._executor.shutdown(wait=True) + self._executor = None - if close_passed_executor and self._executor is not None: - _LOGGER.debug("shutting down executor %s", self._executor) - self._executor.shutdown(wait=True) - self._executor = None + if close_loop: + aio.destroy_loop(loop, _LOGGER) - if close_loop: - aio.destroy_loop(loop, _LOGGER) + _LOGGER.info("successfully terminated") - _LOGGER.info("successfully terminated") + except errors.HikariInterrupt: + _LOGGER.warning("forcefully terminated") + raise async def start( self, diff --git a/hikari/impl/rest_bot.py b/hikari/impl/rest_bot.py index 0613d80c3e..eeb10f9373 100644 --- a/hikari/impl/rest_bot.py +++ b/hikari/impl/rest_bot.py @@ -568,10 +568,10 @@ def run( except AttributeError: _LOGGER.log(ux.TRACE, "cannot set coroutine tracking depth for sys, no functionality exists for this") - try: - with signals.handle_interrupts( - enabled=enable_signal_handlers, loop=loop, propagate_interrupts=propagate_interrupts - ): + with signals.handle_interrupts( + enabled=enable_signal_handlers, loop=loop, propagate_interrupts=propagate_interrupts + ): + try: loop.run_until_complete( self.start( backlog=backlog, @@ -588,22 +588,27 @@ def run( ) loop.run_until_complete(self.join()) - finally: - if self._close_event: - if self._is_closing: - loop.run_until_complete(self._close_event.wait()) - else: - loop.run_until_complete(self.close()) + finally: + try: + if self._close_event: + if self._is_closing: + loop.run_until_complete(self._close_event.wait()) + else: + loop.run_until_complete(self.close()) + + if close_passed_executor and self._executor: + _LOGGER.debug("shutting down executor %s", self._executor) + self._executor.shutdown(wait=True) + self._executor = None - if close_passed_executor and self._executor: - _LOGGER.debug("shutting down executor %s", self._executor) - self._executor.shutdown(wait=True) - self._executor = None + if close_loop: + aio.destroy_loop(loop, _LOGGER) - if close_loop: - aio.destroy_loop(loop, _LOGGER) + _LOGGER.info("successfully terminated") - _LOGGER.info("successfully terminated") + except errors.HikariInterrupt: + _LOGGER.warning("forcefully terminated") + raise async def start( self, diff --git a/hikari/internal/aio.py b/hikari/internal/aio.py index 489e300be9..3b7b75e6d1 100644 --- a/hikari/internal/aio.py +++ b/hikari/internal/aio.py @@ -259,7 +259,7 @@ async def murder(future: asyncio.Future[typing.Any]) -> None: remaining_tasks = tuple(t for t in asyncio.all_tasks(loop) if not t.done()) if remaining_tasks: - logger.debug("terminating %s remaining tasks forcefully", len(remaining_tasks)) + logger.warning("terminating %s remaining tasks forcefully", len(remaining_tasks)) loop.run_until_complete(_gather((murder(task) for task in remaining_tasks))) else: logger.debug("No remaining tasks exist, good job!") diff --git a/hikari/internal/signals.py b/hikari/internal/signals.py index 48220c7496..11dcfdca13 100644 --- a/hikari/internal/signals.py +++ b/hikari/internal/signals.py @@ -38,6 +38,9 @@ from hikari import errors from hikari.internal import ux +if typing.TYPE_CHECKING: + _SignalHandlerT = typing.Callable[[int, typing.Optional[types.FrameType]], None] + _INTERRUPT_SIGNALS: typing.Tuple[str, ...] = ("SIGINT", "SIGTERM") _LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.signals") @@ -49,9 +52,7 @@ def _raise_interrupt(signum: int) -> typing.NoReturn: raise errors.HikariInterrupt(signum, signame) -def _interrupt_handler( - loop: asyncio.AbstractEventLoop, -) -> typing.Callable[[int, typing.Optional[types.FrameType]], None]: +def _interrupt_handler(loop: asyncio.AbstractEventLoop) -> _SignalHandlerT: loop_thread_id = threading.get_native_id() def handler(signum: int, frame: typing.Optional[types.FrameType]) -> None: @@ -102,13 +103,16 @@ def handle_interrupts( return interrupt_handler = _interrupt_handler(loop) + original_handlers: typing.Dict[int, typing.Union[int, _SignalHandlerT, None]] = {} for sig in _INTERRUPT_SIGNALS: try: signum = getattr(signal, sig) - signal.signal(signum, interrupt_handler) except AttributeError: _LOGGER.log(ux.TRACE, "signal %s is not implemented on your platform; skipping", sig) + else: + original_handlers[signum] = signal.getsignal(signum) + signal.signal(signum, interrupt_handler) try: yield @@ -118,10 +122,5 @@ def handle_interrupts( raise finally: - for sig in _INTERRUPT_SIGNALS: - try: - signum = getattr(signal, sig) - signal.signal(signum, signal.SIG_DFL) - except AttributeError: - # Signal not implemented. We already logged this earlier. - pass + for signum, handler in original_handlers.items(): + signal.signal(signum, handler) diff --git a/tests/hikari/internal/test_signals.py b/tests/hikari/internal/test_signals.py index e32a911b78..23a331c759 100644 --- a/tests/hikari/internal/test_signals.py +++ b/tests/hikari/internal/test_signals.py @@ -70,7 +70,9 @@ def test_behaviour(self): register_signal_handler.reset_mock() assert register_signal_handler.call_count == 2 - register_signal_handler.assert_has_calls([mock.call(2, signal.SIG_DFL), mock.call(15, signal.SIG_DFL)]) + register_signal_handler.assert_has_calls( + [mock.call(2, signal.default_int_handler), mock.call(15, signal.SIG_DFL)] + ) def test_when_disabled(self): with mock.patch.object(signal, "signal") as register_signal_handler: