[Oneshot refactor] Refactor initialize_model_from_path #1109

Merged · 9 commits · Feb 11, 2025
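In short: initialize_model_from_path no longer requires TrainingArguments, moves checkpoint detection, seeding, and teacher loading behind a training_args is not None check, and returns (model, teacher) instead of (teacher, model_path, model). A minimal call-pattern sketch (names come from the diff below; the oneshot-only call that omits training_args is an assumption based on the new Optional[TrainingArguments] = None default, not code shown in this PR):

    # Train / distill path: same call site as before, new return order.
    model, teacher = initialize_model_from_path(model_args, training_args)

    # Oneshot-only path (assumed usage): checkpoint detection, seeding, and
    # teacher loading are skipped; teacher comes back as None.
    model, teacher = initialize_model_from_path(model_args)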
88 changes: 48 additions & 40 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -20,6 +20,7 @@
 import os
 import warnings
 from pathlib import PosixPath
+from typing import Optional
 
 from loguru import logger
 from transformers import (
@@ -162,9 +163,8 @@ def parse_args(**kwargs):
 
 def initialize_model_from_path(
     model_args: ModelArguments,
-    training_args: TrainingArguments,
+    training_args: Optional[TrainingArguments] = None,
 ):
-    last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
     # Load pretrained model
     # The .from_pretrained methods guarantee that only one local process can
     # concurrently download model & vocab.
@@ -177,38 +177,42 @@
         tie_word_embeddings=model_args.tie_word_embeddings,
         trust_remote_code=model_args.trust_remote_code_model,
     )
-    teacher_config = (
-        AutoConfig.from_pretrained(
-            model_args.distill_teacher,
-            use_auth_token=True if model_args.use_auth_token else None,
-            tie_word_embeddings=model_args.tie_word_embeddings,
-            trust_remote_code=model_args.trust_remote_code_model,
+
+    last_checkpoint = None
+
+    if training_args is not None:
+        teacher_config = (
+            AutoConfig.from_pretrained(
+                model_args.distill_teacher,
+                use_auth_token=True if model_args.use_auth_token else None,
+                tie_word_embeddings=model_args.tie_word_embeddings,
+                trust_remote_code=model_args.trust_remote_code_model,
+            )
+            if model_args.distill_teacher
+            else None
         )
-        if model_args.distill_teacher
-        else None
-    )
+        last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
+        # Set seed before initializing model.
+        set_seed(training_args.seed)

     model_path = (
         last_checkpoint or model_args.model
         if hasattr(model_args, "model")
         else model_args.model_name_or_path
     )
 
-    # Set seed before initializing model.
-    set_seed(training_args.seed)
-
     # Fallback to CPU if GPU requested and not available
-    training_args.oneshot_device = fallback_to_cpu(training_args.oneshot_device)
+    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
+
     # Trainer handles device assignment for FSDP and training, don't do mapping here
     # if running oneshot outside of FSDP, apply user device settings
-    device_map = None
+
     fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
-    if not fsdp_enabled and training_args.do_oneshot:
-        device_map = training_args.oneshot_device
-        logger.warning(f"Moving {model_path} to device {device_map} for One-Shot")
-    elif not fsdp_enabled:
+
+    device_map = model_args.oneshot_device
+    if not fsdp_enabled and training_args is not None and training_args.do_train:
         device_map = "auto"
 
     model_kwargs = {
         "config": config,
         "cache_dir": model_args.cache_dir,
@@ -218,15 +222,7 @@
         "device_map": device_map,
         "trust_remote_code": model_args.trust_remote_code_model,
     }
-    teacher_device_map = None if fsdp_enabled else "auto"
-    teacher_kwargs = {
-        "config": teacher_config,
-        "cache_dir": model_args.cache_dir,
-        "use_auth_token": True if model_args.use_auth_token else None,
-        "torch_dtype": parse_dtype(model_args.precision),
-        "device_map": teacher_device_map,
-        "trust_remote_code": model_args.trust_remote_code_model,
-    }
+
     # this calls from_pretrained under the hood so should be FSDP safe
 
     # optimized models must be decompressed to carry out oneshot/train/etc
@@ -242,18 +238,30 @@
     if "sequence_length" in model_kwargs:
         model.seqlen = model_kwargs["sequence_length"]
 
-    teacher = (
-        AutoModelForCausalLM.from_pretrained(
-            model_args.distill_teacher,
-            **teacher_kwargs,
+    teacher = None
+    if training_args is not None:
+        teacher_device_map = None if fsdp_enabled else "auto"
+        teacher_kwargs = {
+            "config": teacher_config,
+            "cache_dir": model_args.cache_dir,
+            "use_auth_token": True if model_args.use_auth_token else None,
+            "torch_dtype": parse_dtype(model_args.precision),
+            "device_map": teacher_device_map,
+            "trust_remote_code": model_args.trust_remote_code_model,
+        }
+
+        teacher = (
+            AutoModelForCausalLM.from_pretrained(
+                model_args.distill_teacher,
+                **teacher_kwargs,
+            )
+            if model_args.distill_teacher is not None
+            else None
         )
-        if model_args.distill_teacher is not None
-        else None
-    )
-    if teacher is not None and "sequence_length" in teacher_kwargs:
-        teacher.seqlen = teacher_kwargs["sequence_length"]
+        if teacher is not None and "sequence_length" in teacher_kwargs:
+            teacher.seqlen = teacher_kwargs["sequence_length"]
 
-    return teacher, model_path, model
+    return model, teacher


 def initialize_processor_from_path(
@@ -348,7 +356,7 @@ def main(
 
     model = model_args.model
    if isinstance(model, str) or isinstance(model, PosixPath):
-        (teacher, _model_path, model) = initialize_model_from_path(
+        model, teacher = initialize_model_from_path(
             model_args,
             training_args,
         )
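For reference, the device-map selection after this refactor reads as the following standalone sketch (a paraphrase of the @@ -177,38 +177,42 @@ hunk above; the helper name resolve_device_map is made up for illustration, and fallback_to_cpu is assumed to have already downgraded an unavailable GPU request on model_args.oneshot_device):

    import os

    def resolve_device_map(model_args, training_args=None):
        # Default to the user-requested oneshot device from ModelArguments.
        device_map = model_args.oneshot_device
        fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
        # Outside FSDP, a training run lets from_pretrained place modules with "auto".
        if not fsdp_enabled and training_args is not None and training_args.do_train:
            device_map = "auto"
        return device_map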