diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 61e6441bb..a01b990a9 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -20,6 +20,7 @@
 import os
 import warnings
 from pathlib import PosixPath
+from typing import Optional
 
 from loguru import logger
 from transformers import (
@@ -51,7 +52,7 @@
     patch_tied_tensors_bug,
 )
 from llmcompressor.transformers.sparsification.sparse_model import (
-    get_shared_processor_src,
+    get_processor_name_from_model,
 )
 from llmcompressor.transformers.utils.helpers import (
     detect_last_checkpoint,
@@ -257,10 +258,13 @@ def initialize_model_from_path(
 
 
 def initialize_processor_from_path(
-    model_args: ModelArguments, model: PreTrainedModel, teacher: PreTrainedModel
+    model_args: ModelArguments,
+    model: PreTrainedModel,
+    teacher: Optional[PreTrainedModel] = None,
 ) -> Processor:
-    processor_src = model_args.processor
-    processor_src = processor_src or get_shared_processor_src(model, teacher)
+    processor_src = model_args.processor or get_processor_name_from_model(
+        model, teacher
+    )
     # The use_fast=True option is not currently supported safely in Transformers
     # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727  # noqa: E501
     try:
diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index d7abc323a..74e7666db 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -7,7 +7,7 @@
 
 __all__ = [
     "SparseAutoModelForCausalLM",
-    "get_shared_processor_src",
+    "get_processor_name_from_model",
 ]
 
 
@@ -20,7 +20,7 @@ def from_pretrained(*args, **kwargs):
         return AutoModelForCausalLM.from_pretrained(*args, **kwargs)
 
 
-def get_shared_processor_src(student: Module, teacher: Optional[Module]) -> str:
+def get_processor_name_from_model(student: Module, teacher: Optional[Module]) -> str:
     """
     Get a processor/tokenizer source used for both student and teacher,
     assuming that they could be shared
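
Illustrative usage note (not part of the diff): a minimal sketch of the call-site shape this change enables, assuming model_args, model, and teacher are already set up as in the surrounding finetune flow. Only the function and module names come from the diff; the variables are placeholders.

    from llmcompressor.transformers.finetune.text_generation import (
        initialize_processor_from_path,
    )

    # teacher now defaults to None, so single-model (non-distillation)
    # flows can omit it entirely:
    processor = initialize_processor_from_path(model_args, model)

    # distillation flows still pass the teacher explicitly, and the
    # processor source is resolved from model_args.processor first,
    # falling back to get_processor_name_from_model(model, teacher):
    processor = initialize_processor_from_path(model_args, model, teacher=teacher)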