Skip to content

Commit

Permalink
ThunderModule - fix load_original_state_dict to work with modules wit…
Browse files Browse the repository at this point in the history
…h buffers (#648)
  • Loading branch information
kshitij12345 authored Jun 25, 2024
1 parent 14e6c9b commit 56b04dd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
4 changes: 2 additions & 2 deletions thunder/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_buffer(self, name):
return self._model.get_buffer(name)

def set_buffer(self, name, value):
p = self._overrides_buffers[name] = value
self._overrides_buffers[name] = value

def get_parameter(self, name):
p = self._overrides_parameters.get(name, self._null)
Expand Down Expand Up @@ -117,7 +117,7 @@ def load_original_state_dict(self, state_dict):
full_k = prefix + k
if k in self._overrides_parameters:
self._overrides_parameters[full_k] = v
elif k in model._overrides_buffers:
elif k in self._overrides_buffers:
self._overrides_buffers[full_k] = v
else:
raise NotImplementedError(f"don't know how to handle {full_k}")
Expand Down
21 changes: 21 additions & 0 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,24 @@ def foo(dev, idx):
actual = jfoo("cuda", 0)

assert_close(expected, actual)


def test_load_original_state_dict():
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_parameter("param", torch.nn.Parameter(torch.randn(3)))
self.register_buffer("buffer", torch.randn(3))

def forward(self, x):
return x

m = Model()

thunder_module = thunder.jit(Model())
thunder_module.load_original_state_dict(m.state_dict())

# Check the updated values
# We can't directly compare state_dict - https://github.com/Lightning-AI/lightning-thunder/issues/647
torch.testing.assert_close(thunder_module._overrides_parameters["param"], m.param)
torch.testing.assert_close(thunder_module._overrides_buffers["buffer"], m.buffer)

0 comments on commit 56b04dd

Please sign in to comment.