[Bug]: Cannot load a model with flan-t5 embeddings #3581

Closed
pganesh opened this issue Dec 7, 2024 · 3 comments
Labels: bug (Something isn't working)

@pganesh

pganesh commented Dec 7, 2024

Describe the bug

  • Trained a model successfully using the "google/flan-t5-large" transformer embedding
  • Saved the model without problems
  • Tried to use the saved model by loading it with SequenceTagger.load
  • This fails with the error: TypeError: not a string

To Reproduce

import flair
from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import WordEmbeddings, StackedEmbeddings, FlairEmbeddings, TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer


# define columns
columns = {0: 'text', 1: 'ner'}

# this is the folder in which the train, test and dev files reside
data_folder = './'
data_file = 'train.txt'  # placeholder name for the training file in column format

# set seed so the train, dev and test splits are the same each time
flair.set_seed(42)

# init a corpus using column format, the data folder and the name of the train file
# (dev and test splits are sampled automatically when not specified)
corpus: Corpus = ColumnCorpus(data_folder, columns,
                              train_file=data_file,
                              # dev_file='dev.txt',
                              # test_file='test.txt',
                              )

# what label do we want to predict?
label_type = 'ner'

#  make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)
print(label_dict)

# initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large', # change this to flan-t5 or xlm-roberta-xl; could not fit xl into the GPU, and flan failed to load after training
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )
# initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='ner',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

# initialize trainer
trainer = ModelTrainer(tagger, corpus)

trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  mini_batch_size=4,
                  #mini_batch_chunk_size=1,  # remove this parameter to speed up computation if you have a big GPU
                  )
### Use the trained model to predict

from flair.data import Sentence
from flair.models import SequenceTagger
# load the trained model
model = SequenceTagger.load('/content/resources/taggers/sota-ner-flert/final-model.pt')
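With a working model, the prediction step would then look like this minimal sketch (the example sentence is arbitrary; Sentence, predict and get_spans are the standard flair API):

# create an example sentence and predict NER tags with the loaded model
sentence = Sentence("George Washington went to Washington.")
model.predict(sentence)

# print the predicted entity spans
for entity in sentence.get_spans('ner'):
    print(entity)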

Expected behavior

The trained model should load correctly so it can be used to make predictions.

This exact code works fine if I use the xlm-roberta-large model instead of google/flan-t5-large:

embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

Logs and Stack traces

Here's the error trace:

TypeError Traceback (most recent call last)
in <cell line: 2>()

----> model = SequenceTagger.load(final_model_name)

19 frames
/usr/local/lib/python3.10/dist-packages/flair/models/sequence_tagger_model.py in load(cls, model_path)
923 from typing import cast
924
--> 925 return cast("SequenceTagger", super().load(model_path=model_path))

/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in load(cls, model_path)
562 from typing import cast
563
--> 564 return cast("Classifier", super().load(model_path=model_path))
565
566

/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in load(cls, model_path)
195 state.pop("cls")
196
--> 197 model = cls._init_model_with_state_dict(state)
198
199 if "model_card" in state:

/usr/local/lib/python3.10/dist-packages/flair/models/sequence_tagger_model.py in _init_model_with_state_dict(cls, state, **kwargs)
622 del state["state_dict"]["transitions"]
623
--> 624 return super()._init_model_with_state_dict(
625 state,
626 embeddings=state.get("embeddings"),

/usr/local/lib/python3.10/dist-packages/flair/nn/model.py in _init_model_with_state_dict(cls, state, **kwargs)
98 embeddings = kwargs.pop("embeddings")
99 if isinstance(embeddings, dict):
--> 100 embeddings = load_embeddings(embeddings)
101 kwargs["embeddings"] = embeddings
102

/usr/local/lib/python3.10/dist-packages/flair/embeddings/base.py in load_embeddings(params)
229 cls_name = params.pop("cls")
230 cls = EMBEDDING_CLASSES[cls_name]
--> 231 return cls.load_embedding(params)

/usr/local/lib/python3.10/dist-packages/flair/embeddings/base.py in load_embedding(cls, params)
95 state_dict = params.pop("state_dict", None)
96
---> 97 embedding = cls.from_params(params)
98 if state_dict is not None:
99 embedding.load_state_dict(state_dict)

/usr/local/lib/python3.10/dist-packages/flair/embeddings/transformer.py in from_params(cls, params)
1326 config_class = CONFIG_MAPPING[model_type]
1327 config = config_class.from_dict(config_state_dict)
-> 1328 return cls.create_from_state(saved_config=config, **params)
1329
1330 def to_params(self):

/usr/local/lib/python3.10/dist-packages/flair/embeddings/token.py in create_from_state(cls, **state)
58 # this parameter is fixed
59 del state["is_token_embedding"]
---> 60 return cls(**state)
61
62

/usr/local/lib/python3.10/dist-packages/flair/embeddings/token.py in __init__(self, model, is_document_embedding, allow_long_sentences, **kwargs)
45 **kwargs: Arguments propagated to :meth:flair.embeddings.transformer.TransformerEmbeddings.__init__
46 """
---> 47 TransformerEmbeddings.__init__(
48 self,
49 model=model,

/usr/local/lib/python3.10/dist-packages/flair/embeddings/transformer.py in __init__(self, model, fine_tune, layers, layer_mean, subtoken_pooling, cls_pooling, is_token_embedding, is_document_embedding, allow_long_sentences, use_context, respect_document_boundaries, context_dropout, saved_config, tokenizer_data, feature_extractor_data, name, force_max_length, needs_manual_ocr, use_context_separator, transformers_tokenizer_kwargs, transformers_config_kwargs, transformers_model_kwargs, peft_config, peft_gradient_checkpointing_kwargs, **kwargs)
1062 else:
1063 # load tokenizer from inmemory zip-file
-> 1064 self.tokenizer = self._tokenizer_from_bytes(tokenizer_data)
1065 if feature_extractor_data is not None:
1066 self.feature_extractor = self._feature_extractor_from_bytes(feature_extractor_data)

/usr/local/lib/python3.10/dist-packages/flair/embeddings/transformer.py in _tokenizer_from_bytes(cls, zip_data)
428 with tempfile.TemporaryDirectory() as temp_dir:
429 zip_obj.extractall(temp_dir)
--> 430 return AutoTokenizer.from_pretrained(temp_dir, add_prefix_space=True)
431
432 @classmethod

/usr/local/lib/python3.10/dist-packages/transformers/models/auto/tokenization_auto.py in from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
918 f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
919 )
--> 920 return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
921
922 # Otherwise we have to be creative.

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, trust_remote_code, *init_inputs, **kwargs)
2211 logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
2212
-> 2213 return cls._from_pretrained(
2214 resolved_vocab_files,
2215 pretrained_model_name_or_path,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in _from_pretrained(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, token, cache_dir, local_files_only, _commit_hash, _is_local, trust_remote_code, *init_inputs, **kwargs)
2445 # Instantiate the tokenizer.
2446 try:
-> 2447 tokenizer = cls(*init_inputs, **init_kwargs)
2448 except import_protobuf_decode_error():
2449 logger.info(

/usr/local/lib/python3.10/dist-packages/transformers/models/t5/tokenization_t5_fast.py in __init__(self, vocab_file, tokenizer_file, eos_token, unk_token, pad_token, extra_ids, additional_special_tokens, add_prefix_space, **kwargs)
117 kwargs["from_slow"] = True
118
--> 119 super().__init__(
120 vocab_file,
121 tokenizer_file=tokenizer_file,

/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_fast.py in __init__(self, *args, **kwargs)
130 elif self.slow_tokenizer_class is not None and slow_tokenizer is not False:
131 # We need to create and convert a slow tokenizer to build the backend
--> 132 slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
133 fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
134 elif not slow_tokenizer:

/usr/local/lib/python3.10/dist-packages/transformers/models/t5/tokenization_t5.py in __init__(self, vocab_file, eos_token, unk_token, pad_token, extra_ids, additional_special_tokens, sp_model_kwargs, legacy, add_prefix_space, **kwargs)
148
149 self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
--> 150 self.sp_model.Load(vocab_file)
151
152 if additional_special_tokens is not None:

/usr/local/lib/python3.10/dist-packages/sentencepiece/__init__.py in Load(self, model_file, model_proto)
959 if model_proto:
960 return self.LoadFromSerializedProto(model_proto)
--> 961 return self.LoadFromFile(model_file)
962
963

/usr/local/lib/python3.10/dist-packages/sentencepiece/__init__.py in LoadFromFile(self, arg)
314
315 def LoadFromFile(self, arg):
--> 316 return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
317
318 def _EncodeAsIds(self, text, enable_sampling, nbest_size, alpha, add_bos, add_eos, reverse, emit_unk_piece):

TypeError: not a string

Screenshots

No response

Additional Context

TL;DR:
Training and using the trained model works great with xlm-roberta-large.
Training works fine with flan-t5-large, but loading the trained model fails with a TypeError.

I need some help on how to use the saved/fine-tuned model if the original embedding was based on flan-t5.

Environment

Versions:

Flair: 0.14.0
PyTorch: 2.5.1+cu121
Transformers: 4.46.3
GPU: True

@pganesh added the bug label on Dec 7, 2024
@helpmefindaname changed the title from "[Bug]: Error when loading a trained model: TypeError: not a string" to "[Bug]: Cannot load a model with flan-t5 embeddings" on Dec 20, 2024
@helpmefindaname
Collaborator

Note that the reproducing script doesn't actually reproduce the error: the line embeddings = TransformerWordEmbeddings(model='xlm-roberta-large', ...) loads xlm-roberta-large rather than a flan-t5 model.

I managed to minify the script a bit and reproduce the error with the following:

from pathlib import Path

from flair.data import Dictionary
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger

embeddings = TransformerWordEmbeddings(
    model="google/flan-t5-small",
    layers="-1",
    subtoken_pooling="first",
    fine_tune=False,
)

tagger = SequenceTagger(embeddings, Dictionary(), "ner")
save_path = Path("flan-t5.pt")
tagger.save(save_path)
del tagger
loaded_tagger = SequenceTagger.load(save_path)

It is important to note that the bug requires a newer version of transformers: it does not occur with transformers==4.30.2, and it requires protobuf<3.20.0 to be installed.

@stefan-it
Member

stefan-it commented Dec 20, 2024

I could also reproduce this. My observations so far:

  • Use a T5 model without a fast tokenizer, meaning only the spiece.model file is located on the model hub -> training works fine.
  • Whenever a fast tokenizer version is available, loading fails with that error message. FlanT5 has a fast tokenizer variant, whereas e.g. google/t5-v1_1-base does not have a fast variant uploaded (see the quick check below).
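To verify that observation quickly, one could check which tokenizer files a repo ships (a sketch; list_repo_files comes from the huggingface_hub package, and tokenizer.json is the file a fast tokenizer is stored as):

from huggingface_hub import list_repo_files

# a fast tokenizer ships as tokenizer.json; a slow-only repo has just spiece.model
print("tokenizer.json" in list_repo_files("google/flan-t5-large"))  # expect True
print("tokenizer.json" in list_repo_files("google/t5-v1_1-base"))   # expect False, per the observation above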

In the fast tokenizer case, it helped to remove the add_prefix_space=True option from this call in flair:

model, add_prefix_space=True, **transformers_tokenizer_kwargs, **kwargs

But then you can't use the non-fast variant.
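Reading the traceback, the mechanism seems to be (an interpretation, not confirmed in flair's code): passing add_prefix_space=True makes T5TokenizerFast set from_slow=True (line 117 in tokenization_t5_fast.py above), so transformers rebuilds the tokenizer from the slow SentencePiece variant, which needs a vocab_file. The tokenizer data that flair saved apparently contains no spiece.model, so vocab_file ends up None and the whole trace reduces to:

# the bottom of the traceback, reduced to a minimal sketch
import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()
sp_model.Load(None)  # raises TypeError: not a string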

@helpmefindaname
Collaborator

helpmefindaname commented Dec 20, 2024

This is already fixed via #3544; updating to flair==0.15.0 (today's release) solves the issue.
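After upgrading, a one-line sanity check (flair.__version__ is the package's standard version attribute):

import flair
print(flair.__version__)  # expect 0.15.0 or later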
