[Oneshot refactor] Refactor initialize_model_from_path #1109

Merged · 9 commits · Feb 11, 2025
src/llmcompressor/transformers/finetune/text_generation.py — 96 changes: 54 additions & 42 deletions
@@ -167,9 +167,8 @@ def parse_args(**kwargs):

def initialize_model_from_path(
model_args: ModelArguments,
training_args: TrainingArguments,
training_args: Optional[TrainingArguments] = None,
):
last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
# Load pretrained model
# The .from_pretrained methods guarantee that only one local process can
# concurrently download model & vocab.
@@ -182,38 +181,70 @@ def initialize_model_from_path(
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
teacher_config = (
AutoConfig.from_pretrained(
model_args.distill_teacher,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,

last_checkpoint = None
teacher = None

if training_args is not None:
# Load teacher configuration if applicable
teacher_config = (
AutoConfig.from_pretrained(
model_args.distill_teacher,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
if model_args.distill_teacher
else None
)
if model_args.distill_teacher
else None
)

# Detect last checkpoint
last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)

# Set seed before initializing model
set_seed(training_args.seed)

# Initialize teacher model if teacher path is provided
if model_args.distill_teacher is not None:
teacher_device_map = (
None
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
else "auto"
)
teacher_kwargs = {
"config": teacher_config,
"cache_dir": model_args.cache_dir,
"use_auth_token": True if model_args.use_auth_token else None,
"torch_dtype": parse_dtype(model_args.precision),
"device_map": teacher_device_map,
"trust_remote_code": model_args.trust_remote_code_model,
}

teacher = AutoModelForCausalLM.from_pretrained(
model_args.distill_teacher,
**teacher_kwargs,
)
if "sequence_length" in teacher_kwargs:
teacher.seqlen = teacher_kwargs["sequence_length"]

model_path = (
last_checkpoint or model_args.model
if hasattr(model_args, "model")
else model_args.model_name_or_path
)

# Set seed before initializing model.
set_seed(training_args.seed)

# Fallback to CPU if GPU requested and not available
training_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)

# Trainer handles device assignment for FSDP and training, don't do mapping here
# if running oneshot outside of FSDP, apply user device settings
device_map = None

fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
if not fsdp_enabled and training_args.do_oneshot:
device_map = training_args.oneshot_device
logger.warning(f"Moving {model_path} to device {device_map} for One-Shot")
elif not fsdp_enabled:

device_map = model_args.oneshot_device
if not fsdp_enabled and training_args is not None and training_args.do_train:
device_map = "auto"

model_kwargs = {
"config": config,
"cache_dir": model_args.cache_dir,
@@ -223,15 +254,7 @@ def initialize_model_from_path(
"device_map": device_map,
"trust_remote_code": model_args.trust_remote_code_model,
}
teacher_device_map = None if fsdp_enabled else "auto"
teacher_kwargs = {
"config": teacher_config,
"cache_dir": model_args.cache_dir,
"use_auth_token": True if model_args.use_auth_token else None,
"torch_dtype": parse_dtype(model_args.precision),
"device_map": teacher_device_map,
"trust_remote_code": model_args.trust_remote_code_model,
}

# this calls from_pretrained under the hood so should be FSDP safe

# optimized models must be decompressed to carry out oneshot/train/etc
@@ -247,18 +270,7 @@ def initialize_model_from_path(
if "sequence_length" in model_kwargs:
model.seqlen = model_kwargs["sequence_length"]

teacher = (
AutoModelForCausalLM.from_pretrained(
model_args.distill_teacher,
**teacher_kwargs,
)
if model_args.distill_teacher is not None
else None
)
if teacher is not None and "sequence_length" in teacher_kwargs:
teacher.seqlen = teacher_kwargs["sequence_length"]

return teacher, model_path, model
return model, teacher


def initialize_processor_from_path(
@@ -357,7 +369,7 @@ def main(

model = model_args.model
if isinstance(model, str) or isinstance(model, PosixPath):
(teacher, _model_path, model) = initialize_model_from_path(
model, teacher = initialize_model_from_path(
model_args,
training_args,
)
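
Usage note: below is a minimal sketch, not part of the diff, of the two call paths this refactor enables now that training_args is optional. model_args and training_args are assumed to be built elsewhere (e.g. by parse_args() in this file); nothing beyond the signature and the (model, teacher) return order is taken from the PR itself.

from llmcompressor.transformers.finetune.text_generation import (
    initialize_model_from_path,
)

# Training / distillation path: passing training_args enables checkpoint
# detection, seed setting, and teacher loading (when distill_teacher is set).
model, teacher = initialize_model_from_path(model_args, training_args)

# Oneshot-only path: training_args defaults to None, so those steps are
# skipped and teacher is returned as None.
model, teacher = initialize_model_from_path(model_args)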