Skip to content

Commit

Permalink
prevent autocast on unsupported devices; simplify AsyncTorchModule
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan committed Sep 19, 2024
1 parent 198f156 commit 47e22a5
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions ldp/graph/async_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
_TORCH_LOCK = asyncio.Lock()


def _get_autocast_context(dtype: torch.dtype | None, device_type):
def _get_autocast_context(dtype: torch.dtype | None, device_type: str):
return (
nullcontext()
if dtype is None
if dtype is None and device_type not in {"cpu", "cuda"}
else torch.autocast(dtype=dtype, device_type=device_type)
)

Expand Down Expand Up @@ -130,13 +130,15 @@ async def __call__(self, **kwargs):
# Our request was fulfilled by this or another coroutine!
return self._result_buffer.pop(request_id)

# Try to run a batch
await self._batched_call()
# Try to run a batch. To be safe, set _TORCH_LOCK to prevent other
# coroutines from messing with torch state while running.
async with _TORCH_LOCK:
self._batched_call()

# Sleep, to let another coroutine take over if it needs to
await asyncio.sleep(0.0)

async def _batched_call(self):
def _batched_call(self):
now = time.time()

# sort by oldest requests first
Expand All @@ -155,20 +157,11 @@ async def _batched_call(self):
sample_kwargs = [x[2] for x in batch]
batch_kwargs = self.collate_fn(sample_kwargs)

# Wrap the forward call to be async-safe using the options we want
# Call the module and store results
dtype, device = self._get_dtype_and_device()
protected_call = async_protect_torch_call(
self.module,
module_call_fn=self.module_call_fn,
no_grad=True,
autocast_dtype=dtype,
autocast_device_type=device.type,
)
with torch.no_grad(), _get_autocast_context(dtype, device.type):
batched_results = self.module_call_fn(self.module, **batch_kwargs)

# Call the module and store results
batched_results = await protected_call(
**batch_kwargs,
)
request_ids = [x[1] for x in batch]
results = self.decollate_fn(batched_results)
self._result_buffer.update(zip(request_ids, results, strict=True))

0 comments on commit 47e22a5

Please sign in to comment.