11import  asyncio 
22import  threading 
33from  datetime  import  timedelta 
4- from  typing  import  Optional , TypeVar 
4+ from  typing  import  Callable ,  Optional , TypeVar 
55from  unittest .mock  import  Mock 
66
7+ import  torch 
78from  torch .futures  import  Future 
89
910T  =  TypeVar ("T" )
@@ -17,7 +18,6 @@ def __init__(self) -> None:
1718
1819    def  set_timer (self , timer_handle : asyncio .TimerHandle ) ->  None :
1920        assert  self ._lock .locked ()
20- 
2121        self ._timer_handle  =  timer_handle 
2222        self ._lock .release ()
2323
@@ -99,6 +99,18 @@ def callback(fut: Future[T]) -> None:
9999        fut .add_done_callback (callback )
100100        return  timed_fut 
101101
102+     def  stream_timeout (self , callback : Callable [[], None ], timeout : timedelta ) ->  None :
103+         loop  =  self ._maybe_start_event_loop ()
104+ 
105+         event  =  torch .cuda .Event ()
106+         event .record ()
107+ 
108+         def  handler () ->  None :
109+             if  not  event .query ():
110+                 callback ()
111+ 
112+         loop .call_soon_threadsafe (self ._register_handler , loop , handler , timeout )
113+ 
102114    @classmethod  
103115    def  _register (
104116        cls ,
@@ -116,6 +128,18 @@ def _register(
116128        )
117129        handle .set_timer (timer_handle )
118130
131+     @classmethod  
132+     def  _register_handler (
133+         cls ,
134+         loop ,
135+         handler : Callable [[], None ],
136+         timeout : timedelta ,
137+     ) ->  None :
138+         loop .call_later (
139+             timeout .total_seconds (),
140+             handler ,
141+         )
142+ 
119143
120144_TIMEOUT_MANAGER  =  _TimeoutManager ()
121145
@@ -163,3 +187,18 @@ def callback(fut: Future[T]) -> T:
163187        raise  TimeoutError (f"future did not complete within { timeout }  )
164188
165189    return  fut .wait ()
190+ 
191+ 
192+ def  stream_timeout (callback : Callable [[], None ], timeout : timedelta ) ->  None :
193+     """ 
194+     Registers a callback that will be called after the specified timeout if 
195+     the current stream doesn't complete in time. 
196+ 
197+     This uses a cuda Event to track the completion of the current stream. If 
198+     the stream is not complete after the timeout, the callback is called. 
199+ 
200+     Args: 
201+         callback: The callback to call if the stream doesn't complete in time. 
202+         timeout: The timeout to wait for the stream to complete. 
203+     """ 
204+     _TIMEOUT_MANAGER .stream_timeout (callback , timeout )
0 commit comments