diff --git a/src/stopit/threadstop.py b/src/stopit/threadstop.py index a991750..eb59c04 100644 --- a/src/stopit/threadstop.py +++ b/src/stopit/threadstop.py @@ -12,7 +12,7 @@ import sys import threading -from .utils import TimeoutException, BaseTimeout, base_timeoutable +from .utils import LOG, TimeoutException, BaseTimeout, base_timeoutable if sys.version_info < (3, 7): tid_ctype = ctypes.c_long @@ -30,8 +30,9 @@ def async_raise(target_tid, exception): """ # Ensuring and releasing GIL are useless since we're not in C # gil_state = ctypes.pythonapi.PyGILState_Ensure() - ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid_ctype(target_tid), - ctypes.py_object(exception)) + ret = ctypes.pythonapi.PyThreadState_SetAsyncExc( + tid_ctype(target_tid), ctypes.py_object(exception) + ) # ctypes.pythonapi.PyGILState_Release(gil_state) if ret == 0: raise ValueError("Invalid thread ID {}".format(target_tid)) @@ -46,36 +47,72 @@ class ThreadingTimeout(BaseTimeout): See :class:`stopit.utils.BaseTimeout` for more information """ + + # This class property keep track about who produced the + # exception. + exception_source = None + def __init__(self, seconds, swallow_exc=True): + # Ensure that any new handler find a clear + # pointer super(ThreadingTimeout, self).__init__(seconds, swallow_exc) self.target_tid = threading.current_thread().ident self.timer = None # PEP8 + def __enter__(self): + self.__class__.exception_source = None + self.state = BaseTimeout.EXECUTING + self.setup_interrupt() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + exc_src = self.__class__.exception_source + if exc_type is TimeoutException: + if self.state != BaseTimeout.TIMED_OUT: + self.state = BaseTimeout.INTERRUPTED + self.suppress_interrupt() + LOG.warning( + "Code block execution exceeded {0} seconds timeout".format( + self.seconds + ), + exc_info=(exc_type, exc_val, exc_tb), + ) + if exc_src is self: + if self.swallow_exc: + self.__class__.exception_source = None + return True + return False + else: + if exc_type is None: + self.state = BaseTimeout.EXECUTED + self.suppress_interrupt() + return False + def stop(self): """Called by timer thread at timeout. Raises a Timeout exception in the caller thread """ self.state = BaseTimeout.TIMED_OUT + self.__class__.exception_source = self async_raise(self.target_tid, TimeoutException) # Required overrides def setup_interrupt(self): - """Setting up the resource that interrupts the block - """ + """Setting up the resource that interrupts the block""" self.timer = threading.Timer(self.seconds, self.stop) self.timer.start() def suppress_interrupt(self): - """Removing the resource that interrupts the block - """ + """Removing the resource that interrupts the block""" self.timer.cancel() -class threading_timeoutable(base_timeoutable): #noqa +class threading_timeoutable(base_timeoutable): # noqa """A function or method decorator that raises a ``TimeoutException`` to decorated functions that should not last a certain amount of time. this one uses ``ThreadingTimeout`` context manager. See :class:`.utils.base_timoutable`` class for further comments. """ + to_ctx_mgr = ThreadingTimeout diff --git a/tests.py b/tests.py index 9b1ce5f..c735eee 100644 --- a/tests.py +++ b/tests.py @@ -1,30 +1,92 @@ # -*- coding: utf-8 -*- +import time import doctest import os import unittest -from stopit import ThreadingTimeout, threading_timeoutable, SignalTimeout, signal_timeoutable +from stopit import ( + TimeoutException, + ThreadingTimeout, + threading_timeoutable, + SignalTimeout, + signal_timeoutable, +) # We run twice the same doctest with two distinct sets of globs # This one is for testing signals based timeout control -signaling_globs = { - 'Timeout': SignalTimeout, - 'timeoutable': signal_timeoutable -} +signaling_globs = {"Timeout": SignalTimeout, "timeoutable": signal_timeoutable} # And this one is for testing threading based timeout control -threading_globs = { - 'Timeout': ThreadingTimeout, - 'timeoutable': threading_timeoutable -} +threading_globs = {"Timeout": ThreadingTimeout, "timeoutable": threading_timeoutable} + + +class TestNesting(unittest.TestCase): + handlers = ( + (ThreadingTimeout,) # SignalTimeout, + if os.name == "posix" + else (ThreadingTimeOut,) + ) + + def aware_wait(self, duration): + remaining = duration * 100 + while remaining > 0: + time.sleep(0.01) + remaining = remaining - 1 + return 0 + + def check_nest(self, t1, t2, duration, HandlerClass): + try: + with HandlerClass(t1, swallow_exc=False) as to_ctx_mgr1: + assert to_ctx_mgr1.state == to_ctx_mgr1.EXECUTING + with HandlerClass(t2, swallow_exc=False) as to_ctx_mgr2: + assert to_ctx_mgr2.state == to_ctx_mgr2.EXECUTING + self.aware_wait(duration) + return "success" + except TimeoutException: + if ThreadingTimeout.exception_source is to_ctx_mgr1: + return "outer" + elif ThreadingTimeout.exception_source is to_ctx_mgr2: + return "inner" + else: + print(ThreadingTimeout.exception_source) + return "unknown source" + + def check_nest_swallow(self, t1, t2, duration, HandlerClass): + with HandlerClass(t1) as to_ctx_mgr1: + assert to_ctx_mgr1.state == to_ctx_mgr1.EXECUTING + with HandlerClass(t2) as to_ctx_mgr2: + assert to_ctx_mgr2.state == to_ctx_mgr2.EXECUTING + self.aware_wait(duration) + return "success" + return "inner" + return "outer" + + def test_nested_long_inner(self): + for handler in self.handlers: + self.assertEqual(self.check_nest(1.0, 10.0, 5.0, handler), "outer") + self.assertEqual(self.check_nest_swallow(1.0, 10.0, 5.0, handler), "outer") + + def test_nested_success(self): + for handler in self.handlers: + self.assertEqual(self.check_nest(5.0, 10.0, 1.0, handler), "success") + self.assertEqual( + self.check_nest_swallow(5.0, 10.0, 1.0, handler), "success" + ) + + def test_nested_long_outer(self): + for handler in self.handlers: + self.assertEqual(self.check_nest(10.0, 1.0, 5.0, handler), "inner") + self.assertEqual(self.check_nest_swallow(10.0, 1.0, 5.0, handler), "inner") def suite(): # Func for setuptools.setup(test_suite=xxx) test_suite = unittest.TestSuite() - test_suite.addTest(doctest.DocFileSuite('README.rst', globs=signaling_globs)) - if os.name == 'posix': # Other OS have no support for signal.SIGALRM - test_suite.addTest(doctest.DocFileSuite('README.rst', globs=threading_globs)) + test_suite.addTest(doctest.DocFileSuite("README.rst", globs=threading_globs)) + if os.name == "posix": # Other OS have no support for signal.SIGALRM + test_suite.addTest(doctest.DocFileSuite("README.rst", globs=signaling_globs)) return test_suite -if __name__ == '__main__': + +if __name__ == "__main__": unittest.TextTestRunner(verbosity=2).run(suite()) + unittest.main()