
Commit b4ce42e

feat: rough draft of lora linear
1 parent 3548ba8 commit b4ce42e
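For context, LoRA (Low-Rank Adaptation) keeps a pretrained weight W frozen and adds a trainable low-rank update scaled by alpha / r, so an adapted linear layer computes y = x W^T + (alpha / r) * x A^T B^T. The sketch below shows that standard formulation only as a reference point; the helper name and shapes are illustrative and are not part of this commit. Because B is zero-initialized, the adapted layer initially reproduces the base layer exactly:

    import math
    import torch

    def lora_adapted_linear(x, weight, lora_A, lora_B, r, lora_alpha):
        # frozen base projection: x @ W^T
        base = torch.nn.functional.linear(x, weight)
        # trainable low-rank update on the activations: (x @ A^T) @ B^T, scaled by alpha / r
        update = (x @ lora_A.transpose(0, 1)) @ lora_B.transpose(0, 1)
        return base + (lora_alpha / r) * update

    # illustrative shapes: A is (r, in_features), B is (out_features, r)
    in_features, out_features, r, alpha = 512, 512, 16, 32
    weight = torch.randn(out_features, in_features)
    lora_A = torch.empty(r, in_features)
    lora_B = torch.zeros(out_features, r)  # zero init: no change before training
    torch.nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))

    x = torch.randn(4, in_features)
    out = lora_adapted_linear(x, weight, lora_A, lora_B, r, alpha)
    assert torch.allclose(out, torch.nn.functional.linear(x, weight))

For comparison, the lora_linear helper added below applies the low-rank decomposition to the weight tensor itself inside transform_module, producing a replacement weight, rather than adding the update on the activation path.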

3 files changed: +115 -0 lines changed

thunder/tests/test_transforms.py

Lines changed: 39 additions & 0 deletions
@@ -334,3 +334,42 @@ def get_model():

     assert jm_ref._get_shared_names()["0.weight"] == {"0.weight", "4.weight"}
     assert jm._get_shared_names()["0.weight"] == {"0.weight", "4.weight"}
+
+
+def test_lora_transform_linear():
+    from thunder.transforms import LORATransform
+
+    DIM = 512
+
+    class Model(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.fc1 = torch.nn.Linear(DIM, DIM)
+            self.fc2 = torch.nn.Linear(DIM, DIM)
+
+        def forward(self, x):
+            x = self.fc1(x)
+            x = torch.nn.functional.relu(x)
+            x = self.fc2(x)
+            return x
+
+    model = Model()
+    x = torch.randn(4, DIM)
+
+    loratransform = LORATransform(r=16, lora_alpha=32)
+
+    jmodel = thunder.jit(
+        model,
+        transforms=[
+            loratransform,
+        ],
+    )
+    actual = jmodel(x)
+    original_jmodel = thunder.jit(model)
+    expected = original_jmodel(x)
+
+    print(thunder.last_traces(original_jmodel)[-1])
+    print(thunder.last_traces(jmodel)[-1])
+
+    assert_close(actual, expected, atol=2e-1, rtol=2e-1)
+    assert False == True  # debug leftover: forces a failure so the printed traces are visible
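To run just this new test and see the two printed traces, a standard pytest invocation along these lines should work from the repository root (the command is not part of the commit):

    pytest thunder/tests/test_transforms.py -k test_lora_transform_linear -s

The -s flag keeps the print output visible; note that the trailing assert False == True makes the test fail unconditionally, so as committed the test is only useful for inspecting the traces.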

thunder/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 from .materialization import MaterializationTransform
+from .quantization import LORATransform


 __all__ = [
     "MaterializationTransform",
+    "LORATransform",
 ]

thunder/transforms/quantization.py

Lines changed: 74 additions & 0 deletions
@@ -7,6 +7,7 @@
 from thunder.core import utils
 from thunder.core import prims
 import torch
+import math

 from .utils import (
     get_orig_and_thunder_module_proxies_from_prologue,
@@ -294,3 +295,76 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):

         new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass"))
         return prologue_trace, new_computation_trace, epilogue_trace
+
+
+class LORATransform(Transform):
+    def __init__(
+        self,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout = lora_dropout
+        self.lora_linear_names = set()
+
+    def lora_linear(self, x):
+        # x is the pretrained weight tensor; the result becomes the replacement weight
+        in_features, out_features = x.shape[0], x.shape[1]
+
+        if self.lora_dropout > 0.0:
+            dropout = torch.nn.Dropout(p=self.lora_dropout)
+        else:
+            dropout = lambda x: x
+
+        linear = torch.nn.Linear(in_features, out_features)
+        lora_A = torch.nn.Parameter(torch.empty((self.r, in_features)))
+        lora_B = torch.nn.Parameter(torch.empty((out_features, self.r)))
+        torch.nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
+        torch.nn.init.zeros_(lora_B)  # zero init keeps the low-rank term at zero initially
+        scaling = self.lora_alpha / self.r
+
+        pretrained = linear(x)
+        lora = (dropout(x) @ lora_A.transpose(0, 1) @ lora_B.transpose(0, 1)) * scaling
+        return pretrained + lora
+
+    def transform_module(self, model: thunder.ThunderModule):
+        self.thunder_module = model
+        shared_names = model._get_shared_names()
+        processed_names = set()
+
+        def convert_linear_submodule(tm, name):
+            self.lora_linear_names.add(name)
+            weight_name = f"{name}.weight"
+            processed_copies = shared_names[weight_name] & processed_names
+            if processed_copies:  # weight is shared with an already-processed copy: reuse its override
+                copy_name = next(iter(processed_copies))
+                tm._overrides_parameters[weight_name] = tm._overrides_parameters[copy_name]
+
+            w = tm.get_parameter(weight_name)
+            qw = self.lora_linear(w)
+            tm._overrides_parameters[weight_name] = qw.to(w.device)
+            processed_names.add(weight_name)  # track processed weights so shared copies are detected above
+
+        for n, submodule in model._model.named_modules():
+            if isinstance(submodule, torch.nn.Linear):
+                convert_linear_submodule(model, n)
+
+    def transform_state_dict_for_submodule(
+        self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
+    ) -> dict:
+        if submodule_name not in self.lora_linear_names:
+            return state_dict
+
+        weight_name_full = f"{submodule_name}.weight"
+        w = state_dict["weight"]
+        qw = self.lora_linear(w)
+
+        state_dict = state_dict.copy()
+        state_dict["weight"] = qw.to(w.device)
+
+        return state_dict
+
+    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
+        return prologue_trace, computation_trace, epilogue_trace