From 8fcb9d7689433c0db8d82634d48178d346e3d84d Mon Sep 17 00:00:00 2001 From: d-w-moore Date: Sun, 18 May 2025 22:55:25 -0400 Subject: [PATCH 1/2] [_722] fix segfault and hung threads on SIGINT during parallel get --- irods/parallel.py | 65 +++++++--- irods/test/data_obj_test.py | 7 ++ ...test_signal_handling_in_multithread_get.py | 118 ++++++++++++++++++ 3 files changed, 173 insertions(+), 17 deletions(-) create mode 100644 irods/test/modules/test_signal_handling_in_multithread_get.py diff --git a/irods/parallel.py b/irods/parallel.py index 9a7490efa..e7794b420 100644 --- a/irods/parallel.py +++ b/irods/parallel.py @@ -9,6 +9,7 @@ import concurrent.futures import threading import multiprocessing +import weakref from irods.data_object import iRODSDataObject from irods.exception import DataObjectDoesNotExist @@ -16,6 +17,14 @@ from queue import Queue, Full, Empty +transfer_managers = weakref.WeakKeyDictionary() + + +def abort_asynchronous_transfers(): + for mgr in transfer_managers: + mgr.quit() + + logger = logging.getLogger(__name__) _nullh = logging.NullHandler() logger.addHandler(_nullh) @@ -90,9 +99,11 @@ def __init__( for future in self._futures: future.add_done_callback(self) else: - self.__invoke_done_callback() + self.__invoke_futures_done_logic() + return self.progress = [0, 0] + if (progress_Queue) and (total is not None): self.progress[1] = total @@ -111,7 +122,7 @@ def _progress(Q, this): # - thread to update progress indicator self._progress_fn = _progress self._progress_thread = threading.Thread( - target=self._progress_fn, args=(progress_Queue, self) + target=self._progress_fn, args=(progress_Queue, self), daemon=True ) self._progress_thread.start() @@ -152,11 +163,13 @@ def __call__( with self._lock: self._futures_done[future] = future.result() if len(self._futures) == len(self._futures_done): - self.__invoke_done_callback() + self.__invoke_futures_done_logic( + skip_user_callback=(None in self._futures_done.values()) + ) - def __invoke_done_callback(self): + def __invoke_futures_done_logic(self, skip_user_callback=False): try: - if callable(self.done_callback): + if not skip_user_callback and callable(self.done_callback): self.done_callback(self) finally: self.keep.pop("mgr", None) @@ -239,6 +252,9 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()): bytecount = 0 accum = 0 while True and bytecount < length: + if mgr._quit: + bytecount = None + break buf = src.read(min(COPY_BUF_SIZE, length - bytecount)) buf_len = len(buf) if 0 == buf_len: @@ -274,11 +290,16 @@ class _Multipart_close_manager: """ def __init__(self, initial_io_, exit_barrier_): + self._quit = False self.exit_barrier = exit_barrier_ self.initial_io = initial_io_ self.__lock = threading.Lock() self.aux = [] + def quit(self): + self._quit = True + self.exit_barrier.abort() + def __contains__(self, Io): with self.__lock: return Io is self.initial_io or Io in self.aux @@ -303,8 +324,12 @@ def remove_io(self, Io): Io.close() self.aux.remove(Io) is_initial = False - self.exit_barrier.wait() - if is_initial: + broken = False + try: + self.exit_barrier.wait() + except threading.BrokenBarrierError: + broken = True + if is_initial and not (broken or self._quit): self.finalize() def finalize(self): @@ -439,13 +464,19 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk): Io = File = None if Operation.isNonBlocking(): - if queueLength: - return futures, queueObject, mgr - else: - return futures + return futures, queueObject, mgr else: - bytecounts = [f.result() for f in futures] - return sum(bytecounts), total_size + bytes_transferred = 0 + try: + bytecounts = [f.result() for f in futures] + if None not in bytecounts: + bytes_transferred = sum(bytecounts) + except KeyboardInterrupt: + if any(not f.done() for f in futures): + # Induce any threads still alive to quit the transfer and exit. + mgr.quit() + raise + return bytes_transferred, total_size def io_main(session, Data, opr_, fname, R="", **kwopt): @@ -558,10 +589,10 @@ def io_main(session, Data, opr_, fname, R="", **kwopt): if Operation.isNonBlocking(): - if queueLength > 0: - (futures, chunk_notify_queue, mgr) = retval - else: - futures = retval + (futures, chunk_notify_queue, mgr) = retval + transfer_managers[mgr] = None + + if queueLength <= 0: chunk_notify_queue = total_bytes = None return AsyncNotify( diff --git a/irods/test/data_obj_test.py b/irods/test/data_obj_test.py index 06bb075ee..69a9174c7 100644 --- a/irods/test/data_obj_test.py +++ b/irods/test/data_obj_test.py @@ -2955,6 +2955,13 @@ def test_replica_truncate__issue_534(self): if data_objs.exists(data_path): data_objs.unlink(data_path, force=True) + def test_handling_of_termination_signals_during_multithread_get__issue_722(self): + from irods.test.modules.test_signal_handling_in_multithread_get import ( + test as test__issue_722, + ) + + test__issue_722(self) + if __name__ == "__main__": # let the tests find the parent irods lib diff --git a/irods/test/modules/test_signal_handling_in_multithread_get.py b/irods/test/modules/test_signal_handling_in_multithread_get.py new file mode 100644 index 000000000..2665765bf --- /dev/null +++ b/irods/test/modules/test_signal_handling_in_multithread_get.py @@ -0,0 +1,118 @@ +import os +import re +import signal +import subprocess +import sys +import tempfile +import time + +import irods +import irods.helpers +from irods.test import modules as test_modules + +OBJECT_SIZE = 2 * 1024**3 +OBJECT_NAME = "data_get_issue__722" +LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat" + + +_clock_polling_interval = max(0.01, time.clock_getres(time.CLOCK_BOOTTIME)) + + +def wait_till_true(function, timeout=None): + start_time = time.clock_gettime_ns(time.CLOCK_BOOTTIME) + while not (truth_value := function()): + if ( + timeout is not None + and (time.clock_gettime_ns(time.CLOCK_BOOTTIME) - start_time) * 1e-9 + > timeout + ): + break + time.sleep(_clock_polling_interval) + return truth_value + + +def test(test_case, signal_names=("SIGTERM", "SIGINT")): + """Creates a child process executing a long get() and ensures the process can be + terminated using SIGINT or SIGTERM. + """ + program = os.path.join(test_modules.__path__[0], os.path.basename(__file__)) + + for signal_name in signal_names: + # Call into this same module as a command. This will initiate another Python process that + # performs a lengthy data object "get" operation (see the main body of the script, below.) + process = subprocess.Popen( + [sys.executable, program], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + ) + + # Wait for download process to reach the point of spawning data transfer threads. In Python 3.9+ versions + # of the concurrent.futures module, these are nondaemon threads and will block the exit of the main thread + # unless measures are taken (#722). + localfile = process.stdout.readline().strip() + test_case.assertTrue( + wait_till_true( + lambda: os.path.exists(localfile) + and os.stat(localfile).st_size > OBJECT_SIZE // 2 + ), + "Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.", + ) + + signal_message_info = f"While testing signal {signal_name}" + sig = getattr(signal, signal_name) + + # Interrupt the subprocess with the given signal. + process.send_signal(sig) + # Assert that this signal is what killed the subprocess, rather than a timed out process "wait" or a natural exit + # due to misproper or incomplete handling of the signal. + try: + test_case.assertEqual( + process.wait(timeout=15), + -sig, + "{signal_message_info}: unexpected subprocess return code.", + ) + except subprocess.TimeoutExpired as timeout_exc: + test_case.fail( + f"{signal_message_info}: subprocess timed out before terminating. " + "Non-daemon thread(s) probably prevented subprocess's main thread from exiting." + ) + # Assert that in the case of SIGINT, the process registered a KeyboardInterrupt. + if sig == signal.SIGINT: + test_case.assertTrue( + re.search("KeyboardInterrupt", process.stderr.read()), + "{signal_message_info}: Expected 'KeyboardInterrupt' in log output.", + ) + + +if __name__ == "__main__": + # These lines are run only if the module is launched as a process. + session = irods.helpers.make_session() + hc = irods.helpers.home_collection(session) + TESTFILE_FILL = b"_" * (1024 * 1024) + object_path = f"{hc}/{OBJECT_NAME}" + + # Create the object to be downloaded. + with session.data_objects.open(object_path, "w") as f: + for y in range(OBJECT_SIZE // len(TESTFILE_FILL)): + f.write(TESTFILE_FILL) + local_path = None + # Establish where (ie absolute path) to place the downloaded file, i.e. the get() target. + try: + with tempfile.NamedTemporaryFile( + prefix="local_file_issue_722.dat", delete=True + ) as t: + local_path = t.name + + # Tell the parent process the name of the local file being "get"ted (got) from iRODS + print(local_path) + sys.stdout.flush() + + # "get" the object + session.data_objects.get(object_path, local_path) + finally: + # Clean up, whether or not the download succeeded. + if local_path is not None and os.path.exists(local_path): + os.unlink(local_path) + if session.data_objects.exists(object_path): + session.data_objects.unlink(object_path, force=True) From fb368364365bf42d4da626026938cb4b387e0d9c Mon Sep 17 00:00:00 2001 From: d-w-moore Date: Fri, 6 Jun 2025 04:05:03 -0400 Subject: [PATCH 2/2] use subtest. --- .../modules/test_signal_handling_in_multithread_get.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/irods/test/modules/test_signal_handling_in_multithread_get.py b/irods/test/modules/test_signal_handling_in_multithread_get.py index 2665765bf..7c01f9433 100644 --- a/irods/test/modules/test_signal_handling_in_multithread_get.py +++ b/irods/test/modules/test_signal_handling_in_multithread_get.py @@ -38,6 +38,9 @@ def test(test_case, signal_names=("SIGTERM", "SIGINT")): program = os.path.join(test_modules.__path__[0], os.path.basename(__file__)) for signal_name in signal_names: + + test_case.subTest(f"Testing with signal {signal_name}") + # Call into this same module as a command. This will initiate another Python process that # performs a lengthy data object "get" operation (see the main body of the script, below.) process = subprocess.Popen( @@ -59,7 +62,6 @@ def test(test_case, signal_names=("SIGTERM", "SIGINT")): "Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.", ) - signal_message_info = f"While testing signal {signal_name}" sig = getattr(signal, signal_name) # Interrupt the subprocess with the given signal. @@ -70,18 +72,18 @@ def test(test_case, signal_names=("SIGTERM", "SIGINT")): test_case.assertEqual( process.wait(timeout=15), -sig, - "{signal_message_info}: unexpected subprocess return code.", + "Unexpected subprocess return code.", ) except subprocess.TimeoutExpired as timeout_exc: test_case.fail( - f"{signal_message_info}: subprocess timed out before terminating. " + f"Subprocess timed out before terminating. " "Non-daemon thread(s) probably prevented subprocess's main thread from exiting." ) # Assert that in the case of SIGINT, the process registered a KeyboardInterrupt. if sig == signal.SIGINT: test_case.assertTrue( re.search("KeyboardInterrupt", process.stderr.read()), - "{signal_message_info}: Expected 'KeyboardInterrupt' in log output.", + "Did not find expected string 'KeyboardInterrupt' in log output.", )