diff --git a/ucp/continuous_ucx_progress.py b/ucp/continuous_ucx_progress.py index 2bf27c85a..2ea96f443 100644 --- a/ucp/continuous_ucx_progress.py +++ b/ucp/continuous_ucx_progress.py @@ -25,10 +25,6 @@ def __init__(self, worker, event_loop): self.event_loop = event_loop self.asyncio_task = None - def __del__(self): - if self.asyncio_task is not None: - self.asyncio_task.cancel() - # Hash and equality is based on the event loop def __hash__(self): return hash(self.event_loop) @@ -83,6 +79,11 @@ def _fd_reader_callback(self): # Notice, we can safely overwrite `self.dangling_arm_task` # since previous arm task is finished by now. assert self.asyncio_task is None or self.asyncio_task.done() + from .core import has_context_referrers + + if not has_context_referrers(): + self.asyncio_task = None + return self.asyncio_task = self.event_loop.create_task(self._arm_worker()) async def _arm_worker(self): diff --git a/ucp/core.py b/ucp/core.py index 6f5ddf3c0..20d8836ea 100644 --- a/ucp/core.py +++ b/ucp/core.py @@ -205,7 +205,7 @@ class ApplicationContext: """ def __init__(self, config_dict={}, blocking_progress_mode=None): - self.progress_tasks = [] + self.progress_tasks = dict() # For now, a application context only has one worker self.context = ucx_api.UCXContext(config_dict) @@ -407,7 +407,7 @@ def continuous_ucx_progress(self, event_loop=None): task = BlockingMode(self.worker, loop, self.epoll_fd) else: task = NonBlockingMode(self.worker, loop) - self.progress_tasks.append(task) + self.progress_tasks[loop] = task def get_ucp_worker(self): """Returns the underlying UCP worker handle (ucp_worker_h) @@ -926,6 +926,14 @@ def init(options={}, env_takes_precedence=False, blocking_progress_mode=None): _ctx = ApplicationContext(options, blocking_progress_mode=blocking_progress_mode) +def has_context_referrers(): + global _ctx + if _ctx is not None: + weakref_ctx = weakref.ref(_ctx) + # gc.collect() + return weakref_ctx() is None + + def reset(): """Resets the UCX library by shutting down all of UCX.