From 47e22a5c479b626486e4d07b6c44eaf0e1d44140 Mon Sep 17 00:00:00 2001
From: Siddharth Narayanan
Date: Thu, 19 Sep 2024 19:57:52 +0000
Subject: [PATCH] prevent autocast on unsupported devices; simplify AsyncTorchModule

---
 ldp/graph/async_torch.py | 27 ++++++++++-----------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/ldp/graph/async_torch.py b/ldp/graph/async_torch.py
index a8da0b13..b7aa8987 100644
--- a/ldp/graph/async_torch.py
+++ b/ldp/graph/async_torch.py
@@ -21,10 +21,10 @@ _TORCH_LOCK = asyncio.Lock()
 
 
-def _get_autocast_context(dtype: torch.dtype | None, device_type):
+def _get_autocast_context(dtype: torch.dtype | None, device_type: str):
     return (
         nullcontext()
-        if dtype is None
+        if dtype is None or device_type not in {"cpu", "cuda"}
         else torch.autocast(dtype=dtype, device_type=device_type)
     )
 
 
@@ -130,13 +130,15 @@ async def __call__(self, **kwargs):
                 # Our request was fulfilled by this or another coroutine!
                 return self._result_buffer.pop(request_id)
 
-            # Try to run a batch
-            await self._batched_call()
+            # Try to run a batch. To be safe, hold _TORCH_LOCK to prevent other
+            # coroutines from messing with torch state while running.
+            async with _TORCH_LOCK:
+                self._batched_call()
 
             # Sleep, to let another coroutine take over if it needs to
             await asyncio.sleep(0.0)
 
-    async def _batched_call(self):
+    def _batched_call(self):
         now = time.time()
 
         # sort by oldest requests first
@@ -155,20 +157,11 @@ async def _batched_call(self):
         sample_kwargs = [x[2] for x in batch]
         batch_kwargs = self.collate_fn(sample_kwargs)
 
-        # Wrap the forward call to be async-safe using the options we want
+        # Call the module and store results
         dtype, device = self._get_dtype_and_device()
-        protected_call = async_protect_torch_call(
-            self.module,
-            module_call_fn=self.module_call_fn,
-            no_grad=True,
-            autocast_dtype=dtype,
-            autocast_device_type=device.type,
-        )
+        with torch.no_grad(), _get_autocast_context(dtype, device.type):
+            batched_results = self.module_call_fn(self.module, **batch_kwargs)
 
-        # Call the module and store results
-        batched_results = await protected_call(
-            **batch_kwargs,
-        )
         request_ids = [x[1] for x in batch]
         results = self.decollate_fn(batched_results)
         self._result_buffer.update(zip(request_ids, results, strict=True))
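
Below is a minimal sketch of the behavior the first hunk is after, assuming the
patched module is importable as ldp.graph.async_torch; "mps" here merely stands
in for any device type outside the supported {"cpu", "cuda"} set:

    from contextlib import nullcontext

    import torch

    from ldp.graph.async_torch import _get_autocast_context

    # A dtype on a supported device type yields a real autocast context.
    assert isinstance(_get_autocast_context(torch.bfloat16, "cpu"), torch.autocast)

    # No dtype, or an unsupported device type, yields a no-op nullcontext,
    # so callers never construct torch.autocast on a device that lacks it.
    assert isinstance(_get_autocast_context(None, "cpu"), nullcontext)
    assert isinstance(_get_autocast_context(torch.float16, "mps"), nullcontext)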