Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a transformer based encoder-decoder model #3597

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

chelseagzr
Copy link

@chelseagzr chelseagzr commented Jan 14, 2025

This PR is aimed to address this feature I proposed.
Please let me know if you have any suggestions for this model! Thank you!

Example codes to use the model:

from flair.datasets import OpusParallelCorpus
from flair.embeddings import TransformerEmbeddings
from flair.models import CausalLanguageModelDecoder, EncoderDecoderLanguageModel
from flair.trainers import ModelTrainer

# 1) Create a corpus for the Tatoeba dataset.
#    We will learn a model which can translate German to English.
corpus = OpusParallelCorpus(
    dataset="tatoeba",  # "tatoeba" or "subtitles" are supported in this example
    l1="de",
    l2="en",
    max_tokens_per_doc=512,
)


# 2) Create an encoder and a decoder
encoder_embedding = TransformerEmbeddings("Qwen/Qwen2.5-0.5B-Instruct")
decoder = CausalLanguageModelDecoder("gpt2")


# 3) Define input text and output text for the encoder-decoder model
def input_text_fn1(datapoint):
    return datapoint.first.text


def output_text_fn1(datapoint):
    return datapoint.second.text


# 4) Instantiate models
edm1 = EncoderDecoderLanguageModel(
    encoder_embeddings=encoder_embedding,
    decoder=decoder,
    label_type="translate_de_to_en",
    generate_input_text_fn=input_text_fn1,
    generate_output_text_fn=output_text_fn1,
)


trainer = ModelTrainer(edm1, corpus)
trainer.fine_tune("local", max_epochs=1, mini_batch_size=8, eval_batch_size=8, save_model_each_k_epochs=1)

The two functions, generate_input_text_fn and generate_output_text_fn, provide the flexibility to create different tasks without duplicating the underlying dataset. Example codes to demonstrate this flexibility.

from flair.datasets import OpusParallelCorpus
from flair.embeddings import TransformerEmbeddings
from flair.models import CausalLanguageModelDecoder, EncoderDecoderLanguageModel
from flair.nn.multitask import make_multitask_model_and_corpus
from flair.trainers import ModelTrainer

# 1) Create a corpus for the Tatoeba dataset.
#    We will learn a model which can translate English to German and translate German to English
corpus = OpusParallelCorpus(
    dataset="tatoeba",  # "tatoeba" or "subtitles" are supported in this example
    l1="de",
    l2="en",
    max_tokens_per_doc=512,
)


# 2) Create an encoder and a decoder
encoder_embedding = TransformerEmbeddings("Qwen/Qwen2.5-0.5B-Instruct")
decoder = CausalLanguageModelDecoder("gpt2")


# 3) Define input text and output text for the encoder-decoder model
def input_text_fn1(datapoint):
    task_prompt = "[TASK] translate German to English: [TASK] "
    return task_prompt + datapoint.first.text


def output_text_fn1(datapoint):
    return datapoint.second.text


def input_text_fn2(datapoint):
    task_prompt = "[TASK] translate English to German: [TASK] "
    return task_prompt + datapoint.second.text


def output_text_fn2(datapoint):
    return datapoint.first.text


# 4) Instantiate models
edm1 = EncoderDecoderLanguageModel(
    encoder_embeddings=encoder_embedding,
    decoder=decoder,
    label_type="translate_de_to_en",
    generate_input_text_fn=input_text_fn1,
    generate_output_text_fn=output_text_fn1,
)

edm2 = EncoderDecoderLanguageModel(
    encoder_embeddings=encoder_embedding,
    decoder=decoder,
    label_type="translate_en_to_de",
    generate_input_text_fn=input_text_fn2,
    generate_output_text_fn=output_text_fn2,
)


multitask_model, multicorpus = make_multitask_model_and_corpus([(edm1, corpus), (edm2, corpus)])

trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune("local", max_epochs=1, mini_batch_size=8, eval_batch_size=8, save_model_each_k_epochs=1)

return tied_weights


PreTrainedModel._tie_encoder_decoder_weights = _tie_encoder_decoder_weights
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the shape check really necessary?
I would prefer to need as little patching of external functions as possible, as that function is always subject of change, increasing the bourdon of maintainance.
Maybe it would be an better idea to make a PR to transformers and add the shapecheck there?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I think it's good to have the shape check, but I can remove this patch for now. It might be better to make a PR to transformers.

func_definition (str): the definition of a single function
"""
local_scope = {}
exec(func_definition, local_scope)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows remote code execution and is something that we should be avoided.

I think it would be fair to always expect a SentencePair as input and add an optional parameter prompt_template.
Then you can set prompt_template="[TASK] translate German to English: [TASK] {text}" for the opus example.
Is that enough?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the feedback!
In the second example above, I believe an additional parameter (use_first_as_input) is required to train a model to translate between English and German using a single underlying corpus. This parameter specifies which sentence in the sentence pair should be treated as the input and which as the output.
Here are some code to illustrate the use of the additional parameter use_first_as_input:

class EncoderDecoderLanguageModel(Model, GenerationMixin):
    """A language model based on an encoder-decoder architecture using HuggingFace's Transformers."""

    label_pad_token_id: int = -100  # The index to ignore when calculating cross_entropy loss

    def __init__(
        self,
        encoder_embeddings: Any,
        decoder: CausalLanguageModelDecoder,
        label_type: str,
        use_first_as_input: bool = True,
        task_prompt: str = "",
        tie_encoder_decoder: bool = False,
    ) -> None:
        super().__init__()

        self.encoder_embeddings = encoder_embeddings
        self.decoder = decoder

        self._label_type = label_type
        self._use_first_as_input = use_first_as_input
        self._task_prompt = task_prompt
        self.tie_encoder_decoder = tie_encoder_decoder

        # Initialize EncoderDecoderModel
        ......

    def forward_loss(self, datapoints: list[DataPoint]) -> tuple[torch.Tensor, int]:
        if len(datapoints) == 0:
            raise RuntimeError("No datapoints provided")

        if self._use_first_as_input:
            input_texts = [self._task_prompt + s.first.text for s in datapoints]
            target_texts = [s.second.text for s in datapoints]
        else:
            input_texts = [self._task_prompt + s.second.text for s in datapoints]
            target_texts = [s.first.text for s in datapoints]

        encoder_inputs = self.encoder_tokenizer(
            input_texts,
            padding="longest",
            truncation=True,
            return_tensors="pt",
        ).to(flair.device)

        ...... 

(similar for evaluate and predict)

Then, we instantiate tasks by providing use_first_as_input and task_prompt:

edm1 = EncoderDecoderLanguageModel(
    encoder_embeddings=encoder_embedding,
    decoder=decoder,
    label_type="translate_de_to_en",
    use_first_as_input=True,
    task_prompt="[TASK] translate German to English: [TASK] "
)

edm2 = EncoderDecoderLanguageModel(
    encoder_embeddings=encoder_embedding,
    decoder=decoder,
    label_type="translate_en_to_de",
    use_first_as_input=False,
    task_prompt="[TASK] translate English to German: [TASK] "
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prompt_template and use_first_as_input parameters should be sufficient to replace the generate_input_text_fn and generate_output_text_fn functions in the examples I provided earlier. However, I’m considering using the functions for something beyond simple string concatenation.
For example, I store a list of strings in the _metadata attribute of a datapoint and I want to use the concatenation of these strings as the input for the encoder-decoder model. To make the model invariant to the order of the strings in the input, I want to shuffle the list each time the datapoint is used.

Assume I have a list of strings under the key "ingredients" in self._metadata of a datapoint:
self._metadata = {"ingredients": ["apple", "banana", "pear"]}
For each epoch, I want to see different permutations of the concatenated strings, such as "apple banana pear", "banana apple pear", "pear apple banana", etc.
With the generate_input_text_fn function, this could be achieved by defining:

def input_text_fn(datapair):
    ingredients = datapair.first.get_metadata("ingredients").copy()
    random.shuffle(ingredients)
    return " ".join(ingredients)

Do you have any suggestions on how to accommodate this use case without saving and loading functions? Thank you!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think designwise it is not necessary to have use_first_as_input as you can just create different corpora:

corpus_de_en = OpusParallelCorpus(
    dataset="tatoeba",  # "tatoeba" or "subtitles" are supported in this example
    l1="de",
    l2="en",
    max_tokens_per_doc=512,
)

corpus_en_de = OpusParallelCorpus(
    dataset="tatoeba",  # "tatoeba" or "subtitles" are supported in this example
    l1="en",
    l2="de",
    max_tokens_per_doc=512,
)

When you want to do data-augmentation like with the metadata, you should build your dataset in that way. E.g.:'

class DataAugmentatedDataset(FlairDataset):
   def __init__(self, orig_dataset: FlairDataset) -> None:
       self.orig_dataset = orig_datset
       
   def __len__(self) -> int:
       return len(self.orig_dataset)
       
   def __getitem__(i: int) -> DataPoint:
       return self.transform(self.orig_dataset[i])
       
   def transform(dp: dataPoint) -> DataPoint:
        ....

As you wouldn't want that data-augmentation on inference (and also not in the test set).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the response! I'll remove generate_input_text_fn and generate_output_text_fn and only add prompt_template.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants