materialization with reset_parameters (#980)
t-vi authored Aug 19, 2024
1 parent 54a0371 commit dcf4782
Showing 4 changed files with 211 additions and 26 deletions.
76 changes: 55 additions & 21 deletions thunder/core/module.py
@@ -5,6 +5,7 @@
import collections

import torch as pytorch
from torch.utils.weak import WeakTensorKeyDictionary

import thunder
from thunder.core.compile_data import get_compile_data
@@ -105,6 +106,18 @@ def named_buffers(self, prefix="", recurse=True, remove_duplicate=True, *, persi
            remove_duplicate=remove_duplicate,
        )

    def _get_shared_names(self):
        parameters_to_names = WeakTensorKeyDictionary()
        for name, v in itertools.chain(
            self.named_parameters(remove_duplicate=False), self.named_buffers(remove_duplicate=False)
        ):
            parameters_to_names.setdefault(v, set()).add(name)
        shared_names = {}
        for s in parameters_to_names.values():
            for n in s:
                shared_names[n] = s
        return shared_names
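
For illustration, a minimal sketch (module layout assumed; it mirrors the new test below) of what _get_shared_names reports when two submodules are tied to the same weight tensor:

import torch
import thunder

m0 = torch.nn.Linear(2, 2)
m4 = torch.nn.Linear(2, 2)
m4.weight = m0.weight  # tie the first and last layer's weights
model = torch.nn.Sequential(m0, torch.nn.GELU(), torch.nn.Linear(2, 2), torch.nn.GELU(), m4)

jm = thunder.jit(model)
shared = jm._get_shared_names()
# every alias maps to the full set of names sharing the underlying tensor:
# shared["0.weight"] == shared["4.weight"] == {"0.weight", "4.weight"}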

    def load_original_state_dict(self, state_dict):
        # this loads the state dict incrementally to not exhaust memory
        module_names = {n for n, _ in self._model.named_modules()}
@@ -116,32 +129,53 @@ def load_original_state_dict(self, state_dict):
                prefix, sep, _ = prefix.rpartition(".")
            sd_per_module[prefix][k[len(prefix) + len(sep) :]] = v

        shared_names = self._get_shared_names()
        processed_names = set()

        for submodule_name, sd_part in sd_per_module.items():
            self._transform_and_load_for_submodule(submodule_name, sd_part, shared_names, processed_names)

    def _transform_and_load_for_submodule(self, submodule_name, sd_part, shared_names, processed_names):
        prefix = submodule_name + ("." if submodule_name else "")
        for transform in self._lc_transforms:
            sd_part = transform.transform_state_dict_for_submodule(self, submodule_name, sd_part)

        for k, v in sd_part.items():
            full_k = prefix + k

            # cater for shared parameters: if an alias of this tensor has already been loaded,
            # point this name at the previously loaded override instead of loading it again
            processed_copies = shared_names[full_k] & processed_names
            if processed_copies:
                copy_name = next(iter(processed_copies))
                if full_k in self._overrides_parameters:
                    self._overrides_parameters[full_k] = self._overrides_parameters[copy_name]
                elif full_k in self._overrides_buffers:
                    self._overrides_buffers[full_k] = self._overrides_buffers[copy_name]
                else:
                    raise NotImplementedError(f"don't know how to handle {full_k}")
                processed_names.add(full_k)
                continue

            if full_k in self._overrides_parameters:
                p = self._overrides_parameters[full_k]
                if p.dtype == v.dtype and p.shape == v.shape:
                    with pytorch.no_grad():
                        p.copy_(v)
                else:
                    with pytorch.no_grad():
                        self._overrides_parameters[full_k] = pytorch.nn.Parameter(
                            v.to(p.device), requires_grad=p.requires_grad
                        )
            elif full_k in self._overrides_buffers:
                b = self._overrides_buffers[full_k]
                if b.dtype == v.dtype and b.shape == v.shape:
                    with pytorch.no_grad():
                        b.copy_(v)
                else:
                    with pytorch.no_grad():
                        self._overrides_buffers[full_k] = v.to(b.device).requires_grad_(b.requires_grad)
            else:
                raise NotImplementedError(f"don't know how to handle {full_k}")
            processed_names.add(full_k)
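
A hedged usage sketch (the module is illustrative): load_original_state_dict hands the checkpoint to _transform_and_load_for_submodule one submodule at a time, so all aliases of a shared tensor end up pointing at one loaded override:

import torch
import thunder

model = torch.nn.Linear(2, 2)   # any module, possibly with tied parameters; illustrative
jm = thunder.jit(model)
sd = model.state_dict()         # stands in for an original (untransformed) checkpoint, e.g. torch.load(...)
jm.load_original_state_dict(sd)  # loads incrementally per submodule; shared names are re-tied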

    def state_dict(self, *, destination=None, prefix="", keep_vars=False):
        """
53 changes: 53 additions & 0 deletions thunder/tests/test_transforms.py
@@ -221,3 +221,56 @@ def test_nvfuser_cse():
    assert prologue_proxy.device == thunder.core.devices.to_device(t.device)
    assert comp_proxy.dtype == thunder.core.dtypes.to_dtype(t.dtype)
    assert prologue_proxy.dtype == thunder.core.dtypes.to_dtype(t.dtype)


@requiresCUDA
def test_materialization_init():
    from thunder.transforms import MaterializationTransform
    from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit, get_bitsandbytes_executor

    bitsandbytes_executor = get_bitsandbytes_executor()

    def get_model():
        m0 = torch.nn.Linear(2, 2)

        # to not change the rng state
        with torch.device("meta"):
            m4 = torch.nn.Linear(2, 2)

        m4.weight = m0.weight
        m4.bias = m0.bias

        return (
            torch.nn.Sequential(
                m0,
                torch.nn.GELU(),
                torch.nn.Linear(2, 2),
                torch.nn.GELU(),
                m4,
            )
            .eval()
            .requires_grad_(False)
        )

    torch.manual_seed(1234)
    with torch.device("cuda"):
        m_ref = get_model()
        inp = torch.randn(3, 2)

    jm_ref = thunder.jit(m_ref, transforms=[BitsAndBytesLinearQuant4bit()], executors=(bitsandbytes_executor,))

    torch.manual_seed(1234)
    init_from_module_init = MaterializationTransform.init_from_original_module_init()
    with torch.device("meta"):
        m = get_model()

    jm = thunder.jit(
        m,
        transforms=[BitsAndBytesLinearQuant4bit(), MaterializationTransform("cuda", init=init_from_module_init)],
        executors=(bitsandbytes_executor,),
    )

    assert_close(jm(inp), jm_ref(inp))

    assert jm_ref._get_shared_names()["0.weight"] == {"0.weight", "4.weight"}
    assert jm._get_shared_names()["0.weight"] == {"0.weight", "4.weight"}
95 changes: 90 additions & 5 deletions thunder/transforms/materialization.py
@@ -1,8 +1,11 @@
from __future__ import annotations
from collections.abc import Callable
import copy
from itertools import chain
from typing import TYPE_CHECKING

import torch
from torch.utils.weak import WeakTensorKeyDictionary

from thunder.core.transform_common import Transform

@@ -15,9 +18,18 @@ class MaterializationTransform(Transform):
    Args:
        device: Device to host :class:`~thunder.core.module.ThunderModule` after materialization.
            The transform annotates any unannotated meta-device parameters to be initialized on this device.

    Keyword Args:
        init: Post-processing callable applied to the :class:`~thunder.core.module.ThunderModule` after materialization.
            Possible values are obtained from

            - `MaterializationTransform.init_from_original_state_dict(state_dict)`:
              populates weights from a `state_dict` of the untransformed module,
            - `MaterializationTransform.init_from_transformed_state_dict(state_dict)`:
              populates weights from a `state_dict` of the transformed module,
            - `MaterializationTransform.init_from_original_module_init()`:
              initializes using the weight initialization of the original module (`reset_parameters`).
    """

    def __init__(
@@ -30,14 +42,30 @@ def __init__(
        self.init = init

    def transform_module(self, model: ThunderModule):
        for p in chain(model._model.parameters(), model._model.buffers()):
            if p.device.type == "meta" and not hasattr(p, "_thunder_device"):
                p._thunder_device = self.device

        shared_names = model._get_shared_names()

        # note: the iterations below are without duplicates
        for n, p in list(model.named_parameters()):
            if p.device.type == "meta":
                p = torch.nn.Parameter(
                    torch.empty_like(p, device=getattr(p, "_thunder_device", self.device)),
                    requires_grad=p.requires_grad,
                )
                for nn in shared_names[n]:
                    model._overrides_parameters[nn] = p

        for n, b in list(model.named_buffers()):
            if b.device.type == "meta":
                b = torch.empty_like(
                    b, device=getattr(b, "_thunder_device", self.device), requires_grad=b.requires_grad
                )
                for nn in shared_names[n]:
                    model._overrides_buffers[nn] = b

        self.init(self, model)

    @staticmethod
@@ -55,3 +83,60 @@ def module_init_from_transformed_state_dict(transform: MaterializationTransform,
            model.load_state_dict(state_dict)

        return module_init_from_transformed_state_dict

    @staticmethod
    def init_from_original_module_init():
        def module_init_from_original_module_init(transform: MaterializationTransform, tm: ThunderModule):

            shared_names = tm._get_shared_names()
            processed_names = set()

            # Shared parameters in PyTorch eager are parameters of a module that have different names but
            # share the underlying tensor. For shared parameters, we replace every occurrence with its
            # corresponding `base` parameter: the parameter (and name) that we see first while iterating
            # the parameters below. Names that alias an already-processed tensor are looked up via
            # `shared_names` and tracked in `processed_names`.

            for module_name, _ in tm._model.named_modules():
                prefix = module_name if not module_name else f"{module_name}."
                submodule = tm.get_submodule(module_name)

                # we use a copy to leave the user's module alone
                module_copy = copy.copy(submodule)

                # Materialize meta-parameters on-device if necessary.
                # This is done before sharding in case the materialization logic depends on the tensor shape.
                # The tradeoff is that all of a module's direct parameters need to fit in device memory.
                # Each module only initializes its own parameters and not those of its children (recurse=False)
                if any(
                    t.is_meta for t in chain(module_copy.parameters(recurse=False), module_copy.buffers(recurse=False))
                ):
                    # we need to initialize the module unless all parameters are duplicates
                    need_init = not all(
                        shared_names[n] & processed_names
                        for n, _ in chain(
                            module_copy.named_parameters(prefix=module_name, recurse=False),
                            module_copy.named_buffers(prefix=module_name, recurse=False),
                        )
                    )

                    if need_init:
                        # TODO: we could also support calling a "param_init_fn" argument like PyTorch
                        module_copy.to_empty(device=transform.device, recurse=False)
                        if not hasattr(module_copy, "reset_parameters"):
                            raise TypeError(
                                f"Materialization requires that the `{type(submodule).__name__}.reset_parameters` method is implemented."
                                " This method is used to initialize any children parameters or buffers in this module."
                            )
                        module_copy.reset_parameters()

                # TODO: non-persistent buffers?
                sd = {
                    n: p
                    for n, p in chain(
                        module_copy.named_parameters(recurse=False), module_copy.named_buffers(recurse=False)
                    )
                }
                tm._transform_and_load_for_submodule(module_name, sd, shared_names, processed_names)

        return module_init_from_original_module_init
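
As a hedged usage sketch (the module and the source of the state dict are illustrative; the "cuda" device string assumes a GPU is available, as in the new test), the init options listed in the docstring above plug into thunder.jit together with the transform, e.g. materializing a meta-device module and populating it from the original module's state dict:

import torch
import thunder
from thunder.transforms import MaterializationTransform

# state dict of the original (non-meta) module; stands in for torch.load("checkpoint.pt")
state_dict = torch.nn.Linear(2, 2).state_dict()

with torch.device("meta"):
    m = torch.nn.Linear(2, 2)  # same architecture, built without allocating real storage

jm = thunder.jit(
    m,
    transforms=[
        MaterializationTransform("cuda", init=MaterializationTransform.init_from_original_state_dict(state_dict))
    ],
)
# jm's parameters are materialized on "cuda" and populated from state_dict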
13 changes: 13 additions & 0 deletions thunder/transforms/quantization.py
@@ -104,10 +104,22 @@ def quantize_weight(self, w):

    def transform_module(self, model: thunder.ThunderModule):
        self.thunder_module = model
        shared_names = model._get_shared_names()
        processed_names = set()

        def convert_linear_submodule(tm, name):
            self.quantized_submodule_names.add(name)
            weight_name = f"{name}.weight"
            processed_copies = shared_names[weight_name] & processed_names
            if processed_copies:
                copy_name = next(iter(processed_copies))
                self.quant_states[weight_name] = self.quant_states[copy_name]
                tm._overrides_parameters[weight_name] = tm._overrides_parameters[copy_name]
                tm._overrides_parameters[f"{weight_name}.absmax"] = tm._overrides_parameters[f"{copy_name}.absmax"]
                tm._overrides_parameters[f"{weight_name}.code"] = tm._overrides_parameters[f"{copy_name}.code"]
                processed_names.add(weight_name)
                return

            w = tm.get_parameter(weight_name)
            # TODO: double quant support

@@ -127,6 +139,7 @@ def convert_linear_submodule(tm, name):
                "code.shape": tuple(qs.code.shape),
                "device": getattr(w, "_thunder_device", w.device),
            }
            processed_names.add(weight_name)

        for n, submodule in model._model.named_modules():
            if isinstance(submodule, torch.nn.Linear):
