LoRA Load Planner #169

Open · wants to merge 4 commits into main
8 changes: 8 additions & 0 deletions diffusion/planners/__init__.py
@@ -0,0 +1,8 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Composer checkpointing planners."""

from diffusion.planners.lora_planner import LoraPlanner

__all__ = ['LoraPlanner']
58 changes: 58 additions & 0 deletions diffusion/planners/lora_planner.py
@@ -0,0 +1,58 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""LoRA Planner."""
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata

__all__ = ['LoraPlanner']


class LoraPlanner(DefaultLoadPlanner):
    """Load planner that converts a Composer checkpoint into a LoRA checkpoint."""

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        """Sets up the planner for converting a Composer checkpoint to a LoRA checkpoint.

        Checks every targeted module to see whether it has already been LoRA-processed. If not,
        renames its weights appropriately. If it has, leaves it unchanged for autoresume
        compatibility.

        Args:
            state_dict (STATE_DICT_TYPE): Original torch state dict.
            metadata (Metadata): Any metadata associated with the state dict.
            is_coordinator (bool): Whether the machine this is running on is the coordinator of loading.
        """
        if 'state' not in state_dict:
            super().set_up_planner(state_dict, metadata, is_coordinator)
            return

        self.original_state_dict = state_dict

        state_dict = dict(state_dict.items())
        state_dict['state'] = dict(state_dict['state'].items())
        target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']

        # Iterate over a snapshot of the keys so entries can be popped and re-added safely.
        for key in list(state_dict['state']['model'].keys()):
            for mod in target_modules:
                if f'{mod}.weight' in key:
                    new_key = key.replace(mod, mod + '.base_layer')
                    state_dict['state']['model'][new_key] = state_dict['state']['model'].pop(key)
                    break

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator
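
For illustration, here is a minimal sketch (not part of the diff) of the renaming that set_up_planner applies: weights of the targeted attention projections gain a '.base_layer' segment, matching the names used for the wrapped base weights in a LoRA model, while keys that already contain 'base_layer' and non-targeted keys are left untouched. The example keys and values below are made up.

# Hypothetical state-dict keys; a real checkpoint holds tensors and many more entries.
model_state = {
    'unet.down_blocks.0.attentions.0.to_q.weight': 'q_tensor',
    'unet.down_blocks.0.attentions.0.to_k.weight': 'k_tensor',
    'unet.down_blocks.0.attentions.0.to_out.0.weight': 'out_tensor',
    'unet.down_blocks.0.attentions.0.to_v.base_layer.weight': 'already_lora',  # already processed
    'unet.conv_in.weight': 'conv_tensor',  # not a targeted module
}
target_modules = ['to_k', 'to_v', 'to_q', 'to_out.0']

# Same renaming logic as LoraPlanner.set_up_planner, applied to a plain dict.
for key in list(model_state.keys()):
    for mod in target_modules:
        if f'{mod}.weight' in key:
            model_state[key.replace(mod, mod + '.base_layer')] = model_state.pop(key)
            break

print(sorted(model_state))
# The three targeted weights now end in '.base_layer.weight'; the other two keys are unchanged.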
32 changes: 21 additions & 11 deletions diffusion/train.py
@@ -21,6 +21,7 @@

from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder
from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
from diffusion.planners import LoraPlanner


def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer:
@@ -206,19 +207,28 @@ def train(config: DictConfig) -> None:
                print(f'Instantiating callbacks <{call_conf._target_}>')
                callbacks.append(hydra.utils.instantiate(call_conf))

    if 'fsdp_config' in config.trainer:
        fsdp_config = dict(config.trainer.fsdp_config)
        del config.trainer.fsdp_config
    else:
        fsdp_config = None

    if 'lora_rank' in config.model:
        assert fsdp_config is not None, 'LoRA training requires an FSDP config so the load planner can be attached.'
        fsdp_config['load_planner'] = LoraPlanner

    scheduler = hydra.utils.instantiate(config.scheduler)

    trainer: Trainer = hydra.utils.instantiate(
        config.trainer,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_set,
        optimizers=optimizer,
        model=model,
        loggers=logger,
        algorithms=algorithms,
        schedulers=scheduler,
        callbacks=callbacks,
        fsdp_config=fsdp_config,
    )

    def eval_and_then_train():
        if config.get('eval_first', True):
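
To summarize the wiring in train.py, the short sketch below (not from the PR) mimics the new logic with plain dictionaries standing in for the Hydra-built configs: the FSDP settings are pulled out of the trainer config and, when the model config requests LoRA via lora_rank, the LoraPlanner class is attached as the load_planner used when loading the sharded checkpoint. The config values shown are illustrative placeholders.

from diffusion.planners import LoraPlanner

# Illustrative stand-ins for the Hydra configs used inside train().
trainer_config = {'max_duration': '10000ba', 'fsdp_config': {'sharding_strategy': 'SHARD_GRAD_OP'}}
model_config = {'lora_rank': 16}

# Pull the FSDP settings out of the trainer config so they can be modified and
# passed to the Trainer explicitly.
fsdp_config = trainer_config.pop('fsdp_config', None)

# LoRA runs are expected to use FSDP here, so the custom load planner can be attached.
if 'lora_rank' in model_config:
    assert fsdp_config is not None, 'LoRA training expects an FSDP config for the load planner.'
    fsdp_config['load_planner'] = LoraPlanner

print(fsdp_config)
# {'sharding_strategy': 'SHARD_GRAD_OP', 'load_planner': <class 'diffusion.planners.lora_planner.LoraPlanner'>}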