 from thunder.core import utils
 from thunder.core import prims
 import torch
+import math

 from .utils import (
     get_orig_and_thunder_module_proxies_from_prologue,
@@ -294,3 +295,76 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo

         new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass"))
         return prologue_trace, new_computation_trace, epilogue_trace
+
+
+class LORATransform(Transform):
+    def __init__(
+        self,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout = lora_dropout
+        self.lora_linear_names = set()
+
+    def lora_linear(self, x):
+        # x is an nn.Linear weight, stored as (out_features, in_features)
+        out_features, in_features = x.shape[0], x.shape[1]
+
+        if self.lora_dropout > 0.0:
+            dropout = torch.nn.Dropout(p=self.lora_dropout)
+        else:
+            dropout = lambda x: x
+
+        linear = torch.nn.Linear(in_features, out_features)
+        # low-rank factors: A is (r, in_features), B is (out_features, r);
+        # B is zero-initialized, so the LoRA term contributes nothing initially
+        lora_A = torch.nn.Parameter(torch.empty((self.r, in_features)))
+        lora_B = torch.nn.Parameter(torch.empty((out_features, self.r)))
+        torch.nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
+        torch.nn.init.zeros_(lora_B)
+        scaling = self.lora_alpha / self.r
+
+        pretrained = linear(x)
+        lora = (dropout(x) @ lora_A.transpose(0, 1) @ lora_B.transpose(0, 1)) * scaling
+        return pretrained + lora
+
+    def transform_module(self, model: thunder.ThunderModule):
+        self.thunder_module = model
+        shared_names = model._get_shared_names()
+        processed_names = set()
+
+        def convert_linear_submodule(tm, name):
+            self.lora_linear_names.add(name)
+            weight_name = f"{name}.weight"
+            processed_copies = shared_names[weight_name] & processed_names
+            if processed_copies:
+                # a shared copy of this weight was already transformed; reuse its override
+                copy_name = next(iter(processed_copies))
+                tm._overrides_parameters[weight_name] = tm._overrides_parameters[copy_name]
+                processed_names.add(weight_name)
+                return
+
+            w = tm.get_parameter(weight_name)
+            qw = self.lora_linear(w)
+            tm._overrides_parameters[weight_name] = qw.to(w.device)
+            processed_names.add(weight_name)
+
+        for n, submodule in model._model.named_modules():
+            if isinstance(submodule, torch.nn.Linear):
+                convert_linear_submodule(model, n)
+
+    def transform_state_dict_for_submodule(
+        self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict
+    ) -> dict:
+        if submodule_name not in self.lora_linear_names:
+            return state_dict
+
+        w = state_dict["weight"]
+        qw = self.lora_linear(w)
+
+        state_dict = state_dict.copy()
+        state_dict["weight"] = qw.to(w.device)
+
+        return state_dict
+
+    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
+        return prologue_trace, computation_trace, epilogue_trace
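
For review context, a minimal usage sketch of the new transform (assuming thunder.jit's transforms= argument; the import path, toy model, and r/lora_alpha values below are illustrative, not part of this diff):

    import torch
    import thunder
    from thunder.transforms import LORATransform  # import path assumed; the class is added in this diff

    # toy model, illustrative only
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))

    # hypothetical hyperparameters; r must be > 0 because scaling divides by self.r
    lora = LORATransform(r=8, lora_alpha=16, lora_dropout=0.0)

    # transform_module replaces each nn.Linear weight with its LoRA-adjusted override
    jitted = thunder.jit(model, transforms=[lora])
    out = jitted(torch.randn(2, 64))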