[Oneshot refactor] Refactor initialize_model_from_path #1109

Open · wants to merge 6 commits into main
97 changes: 55 additions & 42 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -20,6 +20,7 @@
import os
import warnings
from pathlib import PosixPath
from typing import Optional

from loguru import logger
from transformers import (
@@ -162,9 +163,8 @@ def parse_args(**kwargs):

def initialize_model_from_path(
model_args: ModelArguments,
training_args: TrainingArguments,
training_args: Optional[TrainingArguments] = None,
):
last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
# Load pretrained model
# The .from_pretrained methods guarantee that only one local process can
# concurrently download model & vocab.
@@ -177,38 +177,70 @@
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
teacher_config = (
AutoConfig.from_pretrained(
model_args.distill_teacher,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,

last_checkpoint = None
teacher = None

if training_args is not None:
# Load teacher configuration if applicable
teacher_config = (
AutoConfig.from_pretrained(
model_args.distill_teacher,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
if model_args.distill_teacher
else None
)
if model_args.distill_teacher
else None
)

# Detect last checkpoint
last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)

# Set seed before initializing model
set_seed(training_args.seed)

# Initialize teacher model if teacher path is provided
if model_args.distill_teacher is not None:
teacher_device_map = (
None
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
else "auto"
)
teacher_kwargs = {
"config": teacher_config,
"cache_dir": model_args.cache_dir,
"use_auth_token": True if model_args.use_auth_token else None,
"torch_dtype": parse_dtype(model_args.precision),
"device_map": teacher_device_map,
"trust_remote_code": model_args.trust_remote_code_model,
}

teacher = AutoModelForCausalLM.from_pretrained(
model_args.distill_teacher,
**teacher_kwargs,
)
if "sequence_length" in teacher_kwargs:
teacher.seqlen = teacher_kwargs["sequence_length"]

model_path = (
last_checkpoint or model_args.model
if hasattr(model_args, "model")
else model_args.model_name_or_path
)

# Set seed before initializing model.
set_seed(training_args.seed)

# Fallback to CPU if GPU requested and not available
training_args.oneshot_device = fallback_to_cpu(training_args.oneshot_device)
model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)

# Trainer handles device assignment for FSDP and training, don't do mapping here
# if running oneshot outside of FSDP, apply user device settings
device_map = None

fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
if not fsdp_enabled and training_args.do_oneshot:
device_map = training_args.oneshot_device
logger.warning(f"Moving {model_path} to device {device_map} for One-Shot")
elif not fsdp_enabled:

device_map = model_args.oneshot_device
if not fsdp_enabled and training_args is not None and training_args.do_train:
device_map = "auto"

model_kwargs = {
"config": config,
"cache_dir": model_args.cache_dir,
@@ -218,15 +250,7 @@
"device_map": device_map,
"trust_remote_code": model_args.trust_remote_code_model,
}
teacher_device_map = None if fsdp_enabled else "auto"
teacher_kwargs = {
"config": teacher_config,
"cache_dir": model_args.cache_dir,
"use_auth_token": True if model_args.use_auth_token else None,
"torch_dtype": parse_dtype(model_args.precision),
"device_map": teacher_device_map,
"trust_remote_code": model_args.trust_remote_code_model,
}

# this calls from_pretrained under the hood so should be FSDP safe

# optimized models must be decompressed to carry out oneshot/train/etc
@@ -242,18 +266,7 @@
if "sequence_length" in model_kwargs:
model.seqlen = model_kwargs["sequence_length"]

teacher = (
AutoModelForCausalLM.from_pretrained(
model_args.distill_teacher,
**teacher_kwargs,
)
if model_args.distill_teacher is not None
else None
)
if teacher is not None and "sequence_length" in teacher_kwargs:
teacher.seqlen = teacher_kwargs["sequence_length"]

return teacher, model_path, model
return model, teacher


def initialize_processor_from_path(
@@ -348,7 +361,7 @@ def main(

model = model_args.model
if isinstance(model, str) or isinstance(model, PosixPath):
(teacher, _model_path, model) = initialize_model_from_path(
model, teacher = initialize_model_from_path(
model_args,
training_args,
)
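
For reviewers, a minimal usage sketch of the refactored entry point after this change, reflecting the new `training_args: Optional[TrainingArguments] = None` default and the `model, teacher` return order. The import paths and the example model name are assumptions for illustration, not taken from this PR.

```python
# Sketch only: import paths and example values are assumptions, not from this PR.
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.transformers.finetune.text_generation import (
    initialize_model_from_path,
)

# Hypothetical checkpoint; any causal-LM path would do here.
model_args = ModelArguments(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Oneshot path: training_args defaults to None, so checkpoint detection,
# seeding, and teacher loading are skipped and teacher comes back as None.
model, teacher = initialize_model_from_path(model_args)
assert teacher is None

# Training / distillation path: passing TrainingArguments re-enables
# detect_last_checkpoint, set_seed, and teacher initialization inside the function.
# model, teacher = initialize_model_from_path(model_args, training_args)
```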