Skip to content

Commit

Permalink
Add deallocation before forward
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunwoongko committed Dec 29, 2021
1 parent e52b73a commit 18c89f6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
22 changes: 22 additions & 0 deletions parallelformers/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,23 @@ def parallelize(self) -> None:
traceback.print_exc()
self.deparallelize()

@staticmethod
def _deallocate(item):
    """Move CUDA tensors held by ``item`` to the CPU and return the result.

    ``torch.Tensor.cpu()`` is not an in-place operation: it returns a
    new tensor and leaves the original on its device. The CPU copy
    therefore has to be stored, and callers must use the return value
    of this function rather than the object they passed in.

    Only ``item`` itself and the direct elements of a list, tuple or
    dict are inspected; nested containers are not traversed.

    Args:
        item: A tensor, a list/tuple of values, a dict of values, or
            any other object.

    Returns:
        * tensor: its CPU copy when it lives on a CUDA device,
          otherwise the very same tensor.
        * list / dict: the same object, with its CUDA tensors replaced
          in place by CPU copies (the container type is preserved).
        * tuple: the same tuple when it holds no CUDA tensor, otherwise
          a new tuple (namedtuple types are preserved) holding CPU
          copies.
        * anything else: returned unchanged.
    """

    def _on_gpu(value):
        # True only for real tensors that currently live on a CUDA device.
        return torch.is_tensor(value) and value.is_cuda

    if _on_gpu(item):
        return item.cpu()

    if isinstance(item, tuple):
        # Tuples are immutable, so a new one is built -- but only when
        # something actually has to move, to keep identity (and exotic
        # tuple subclasses) intact otherwise.
        if not any(_on_gpu(value) for value in item):
            return item
        moved = [value.cpu() if _on_gpu(value) else value for value in item]
        if hasattr(item, "_fields"):
            # namedtuple constructors take positional fields, not an iterable.
            return type(item)(*moved)
        return tuple(moved)

    if isinstance(item, list):
        for index, value in enumerate(item):
            if _on_gpu(value):
                item[index] = value.cpu()

    elif isinstance(item, dict):
        # Snapshot the items so reassigning values cannot disturb iteration.
        for key, value in list(item.items()):
            if _on_gpu(value):
                item[key] = value.cpu()

    return item

@torch.no_grad()
def hijack(
self,
Expand All @@ -314,6 +331,11 @@ def hijack(
self.inference_mutexes,
self.inputs_queues,
):
inputs = self._deallocate(inputs)

for k in kwargs:
kwargs[k] = self._deallocate(kwargs[k])

i_queue.put((inputs, kwargs, func))
i_mutex.set()
# producer part
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

setup(
name='parallelformers',
version='1.2.3',
version='1.2.4',
description=
'An Efficient Model Parallelization Toolkit for Deployment',
long_description=long_description,
Expand Down

0 comments on commit 18c89f6

Please sign in to comment.