@@ -153,7 +153,6 @@ def _cpu_yield(self, check_context: bool = True):
         # It is critical for correctness that only one thread is running
         # at a time. These asserts just make sure that this is the only
         # thread running before waking the other one up and going to sleep
-        print(f"CPU yield {self.id} {type(forward_context._forward_context)} {type(self.forward_context)}")
         assert (
             not check_context or
             forward_context._forward_context is self.forward_context)
@@ -172,11 +171,6 @@ def switch_to_comm_sync(self):
         self.update_stream(self.comm_stream)
         self._wait_comm_done()

-    def switch_to_compute_sync(self):
-        self._signal_comm_done()
-        self.update_stream(self.compute_stream)
-        self._wait_compute_done()
-
     def maybe_run_recv_hook(self):
         if self.recv_hook is not None:
             self.recv_hook()
@@ -291,7 +285,6 @@ def make_ubatch_contexts(
     comm_stream: torch.cuda.Stream,
     forward_contexts: list[ForwardContext],
     ready_barrier: threading.Barrier,
-    device: Optional[torch.device] = None,
     schedule: Schedule = Schedule.MLP_OVERLAP,
     delayed_start: bool = False,
 ) -> list[UBatchContext]: