Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions src/stopit/threadstop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import sys
import threading

from .utils import TimeoutException, BaseTimeout, base_timeoutable
from .utils import LOG, TimeoutException, BaseTimeout, base_timeoutable

if sys.version_info < (3, 7):
tid_ctype = ctypes.c_long
Expand All @@ -30,8 +30,9 @@ def async_raise(target_tid, exception):
"""
# Ensuring and releasing GIL are useless since we're not in C
# gil_state = ctypes.pythonapi.PyGILState_Ensure()
ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid_ctype(target_tid),
ctypes.py_object(exception))
ret = ctypes.pythonapi.PyThreadState_SetAsyncExc(
tid_ctype(target_tid), ctypes.py_object(exception)
)
# ctypes.pythonapi.PyGILState_Release(gil_state)
if ret == 0:
raise ValueError("Invalid thread ID {}".format(target_tid))
Expand All @@ -46,36 +47,72 @@ class ThreadingTimeout(BaseTimeout):

See :class:`stopit.utils.BaseTimeout` for more information
"""

# This class property keep track about who produced the
# exception.
exception_source = None

def __init__(self, seconds, swallow_exc=True):
# Ensure that any new handler find a clear
# pointer
super(ThreadingTimeout, self).__init__(seconds, swallow_exc)
self.target_tid = threading.current_thread().ident
self.timer = None # PEP8

def __enter__(self):
self.__class__.exception_source = None
self.state = BaseTimeout.EXECUTING
self.setup_interrupt()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
exc_src = self.__class__.exception_source
if exc_type is TimeoutException:
if self.state != BaseTimeout.TIMED_OUT:
self.state = BaseTimeout.INTERRUPTED
self.suppress_interrupt()
LOG.warning(
"Code block execution exceeded {0} seconds timeout".format(
self.seconds
),
exc_info=(exc_type, exc_val, exc_tb),
)
if exc_src is self:
if self.swallow_exc:
self.__class__.exception_source = None
return True
return False
else:
if exc_type is None:
self.state = BaseTimeout.EXECUTED
self.suppress_interrupt()
return False

def stop(self):
"""Called by timer thread at timeout. Raises a Timeout exception in the
caller thread
"""
self.state = BaseTimeout.TIMED_OUT
self.__class__.exception_source = self
async_raise(self.target_tid, TimeoutException)

# Required overrides
def setup_interrupt(self):
"""Setting up the resource that interrupts the block
"""
"""Setting up the resource that interrupts the block"""
self.timer = threading.Timer(self.seconds, self.stop)
self.timer.start()

def suppress_interrupt(self):
"""Removing the resource that interrupts the block
"""
"""Removing the resource that interrupts the block"""
self.timer.cancel()


class threading_timeoutable(base_timeoutable): #noqa
class threading_timeoutable(base_timeoutable): # noqa
"""A function or method decorator that raises a ``TimeoutException`` to
decorated functions that should not last a certain amount of time.
this one uses ``ThreadingTimeout`` context manager.

See :class:`.utils.base_timoutable`` class for further comments.
"""

to_ctx_mgr = ThreadingTimeout
88 changes: 75 additions & 13 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,92 @@
# -*- coding: utf-8 -*-
import time
import doctest
import os
import unittest

from stopit import ThreadingTimeout, threading_timeoutable, SignalTimeout, signal_timeoutable
from stopit import (
TimeoutException,
ThreadingTimeout,
threading_timeoutable,
SignalTimeout,
signal_timeoutable,
)

# We run twice the same doctest with two distinct sets of globs
# This one is for testing signals based timeout control
signaling_globs = {
'Timeout': SignalTimeout,
'timeoutable': signal_timeoutable
}
signaling_globs = {"Timeout": SignalTimeout, "timeoutable": signal_timeoutable}

# And this one is for testing threading based timeout control
threading_globs = {
'Timeout': ThreadingTimeout,
'timeoutable': threading_timeoutable
}
threading_globs = {"Timeout": ThreadingTimeout, "timeoutable": threading_timeoutable}


class TestNesting(unittest.TestCase):
handlers = (
(ThreadingTimeout,) # SignalTimeout,
if os.name == "posix"
else (ThreadingTimeOut,)
)

def aware_wait(self, duration):
remaining = duration * 100
while remaining > 0:
time.sleep(0.01)
remaining = remaining - 1
return 0

def check_nest(self, t1, t2, duration, HandlerClass):
try:
with HandlerClass(t1, swallow_exc=False) as to_ctx_mgr1:
assert to_ctx_mgr1.state == to_ctx_mgr1.EXECUTING
with HandlerClass(t2, swallow_exc=False) as to_ctx_mgr2:
assert to_ctx_mgr2.state == to_ctx_mgr2.EXECUTING
self.aware_wait(duration)
return "success"
except TimeoutException:
if ThreadingTimeout.exception_source is to_ctx_mgr1:
return "outer"
elif ThreadingTimeout.exception_source is to_ctx_mgr2:
return "inner"
else:
print(ThreadingTimeout.exception_source)
return "unknown source"

def check_nest_swallow(self, t1, t2, duration, HandlerClass):
with HandlerClass(t1) as to_ctx_mgr1:
assert to_ctx_mgr1.state == to_ctx_mgr1.EXECUTING
with HandlerClass(t2) as to_ctx_mgr2:
assert to_ctx_mgr2.state == to_ctx_mgr2.EXECUTING
self.aware_wait(duration)
return "success"
return "inner"
return "outer"

def test_nested_long_inner(self):
for handler in self.handlers:
self.assertEqual(self.check_nest(1.0, 10.0, 5.0, handler), "outer")
self.assertEqual(self.check_nest_swallow(1.0, 10.0, 5.0, handler), "outer")

def test_nested_success(self):
for handler in self.handlers:
self.assertEqual(self.check_nest(5.0, 10.0, 1.0, handler), "success")
self.assertEqual(
self.check_nest_swallow(5.0, 10.0, 1.0, handler), "success"
)

def test_nested_long_outer(self):
for handler in self.handlers:
self.assertEqual(self.check_nest(10.0, 1.0, 5.0, handler), "inner")
self.assertEqual(self.check_nest_swallow(10.0, 1.0, 5.0, handler), "inner")


def suite(): # Func for setuptools.setup(test_suite=xxx)
test_suite = unittest.TestSuite()
test_suite.addTest(doctest.DocFileSuite('README.rst', globs=signaling_globs))
if os.name == 'posix': # Other OS have no support for signal.SIGALRM
test_suite.addTest(doctest.DocFileSuite('README.rst', globs=threading_globs))
test_suite.addTest(doctest.DocFileSuite("README.rst", globs=threading_globs))
if os.name == "posix": # Other OS have no support for signal.SIGALRM
test_suite.addTest(doctest.DocFileSuite("README.rst", globs=signaling_globs))
return test_suite

if __name__ == '__main__':

if __name__ == "__main__":
unittest.TextTestRunner(verbosity=2).run(suite())
unittest.main()