Skip to content

Commit

Permalink
Fix align_module_device, ensure only cpu tensors for `get_state_dic…
Browse files Browse the repository at this point in the history
…t_offloaded_model` (#3217)

* only onload direct parameter descendants, move buffers to cpu, add tests

* remove no longer applicable comment
  • Loading branch information
kylesayrs authored Nov 5, 2024
1 parent bf4572b commit c0552c9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,7 +1540,7 @@ def get_state_dict_offloaded_model(model: nn.Module):
placeholders.add(name + f".{key}")
continue
params = module_state_dict[key]
state_dict[name + f".{key}"] = params
state_dict[name + f".{key}"] = params.to("cpu") # move buffers to cpu
for key in placeholders.copy():
if key in state_dict:
placeholders.remove(key)
Expand Down Expand Up @@ -1923,7 +1923,7 @@ def align_module_device(module: torch.nn.Module, execution_device: Optional[torc
module._hf_hook.execution_device = original_device

elif execution_device is not None:
devices = {name: param.device for name, param in module.named_parameters()}
devices = {name: param.device for name, param in module.named_parameters(recurse=False)}
try:
for name in devices:
set_module_tensor_to_device(module, name, execution_device)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
convert_file_size_to_int,
find_tied_parameters,
get_balanced_memory,
get_state_dict_offloaded_model,
infer_auto_device_map,
load_checkpoint_in_model,
load_state_dict,
Expand All @@ -66,6 +67,15 @@ def forward(self, x):
return self.linear2(self.batchnorm(self.linear1(x)))


class NestedModelForTest(nn.Module):
def __init__(self):
super().__init__()
self.model = ModelForTest()

def forward(self, x):
return self.model(x)


class LinearWithNonPersistentBuffers(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
Expand Down Expand Up @@ -788,6 +798,19 @@ def test_convert_file_size(self):
with self.assertRaises(ValueError):
convert_file_size_to_int("-1GB")

def test_get_state_dict_offloaded_model(self):
for model_cls in (ModelForTest, NestedModelForTest):
model = model_cls()
execution_device = torch.device(torch_device)
original_state_dict = model.state_dict()

cpu_offload(model, execution_device=execution_device)
state_dict = get_state_dict_offloaded_model(model)

assert original_state_dict.keys() == state_dict.keys()
for key in original_state_dict:
assert torch.equal(original_state_dict[key], state_dict[key])

def test_align_module_device_simple(self):
model = ModelForTest()
execution_device = torch.device(torch_device)
Expand Down Expand Up @@ -834,3 +857,13 @@ def test_align_module_device_offloaded(self):
assert model.linear1.weight.device == offload_device
assert model.batchnorm.weight.device == offload_device
assert model.linear2.weight.device == offload_device

def test_align_module_device_offloaded_nested(self):
model = NestedModelForTest()
execution_device = torch.device(torch_device)
align_device = torch.device("cpu")
cpu_offload(model, execution_device=execution_device)
for module in model.modules():
with align_module_device(module, align_device):
for param in model.parameters(recurse=False):
assert param.device == align_device

0 comments on commit c0552c9

Please sign in to comment.