-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathsentence_permutation.py
26 lines (23 loc) · 979 Bytes
/
sentence_permutation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import math
from dataclasses import dataclass
from typing import Dict
import jax.numpy as jnp
import nltk
import numpy as np
from jax import ops, random
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from data_collator import DataCollatorForSentencePermutation, DataCollatorForTextInfilling, SentenceTokenize
def main(
    *,
    text=" My dog is cute. It loves to play in the park. There are many parks in SF.",
    model_name="facebook/bart-base",
):
    """Demo the sentence-permutation and text-infilling collators on one example.

    The text is first split into sentences with ``SentenceTokenize``. It is then
    tokenized with the ``model_name`` tokenizer and passed through
    ``DataCollatorForSentencePermutation``, then ``DataCollatorForTextInfilling``.
    The text is printed after each of the three stages.

    Args:
        text: Raw input text. The leading space is kept from the original example;
            presumably it matters for BART's byte-level BPE -- TODO confirm.
        model_name: Model id passed to ``AutoTokenizer.from_pretrained``.
    """
    example = {"text": text}
    sent_tok = SentenceTokenize()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    permute_sent = DataCollatorForSentencePermutation(tokenizer)

    # Stage 0: sentence-split the raw text. The exact output format is defined
    # by SentenceTokenize in data_collator (not visible here).
    example = sent_tok(example)
    print(example["text"])

    # Stage 1: tokenize without special tokens, permute sentences, decode back.
    out = permute_sent(tokenizer(example["text"], add_special_tokens=False))
    example["text"] = tokenizer.decode(out["input_ids"])
    print(example["text"])

    # Stage 2: apply text infilling to the permuted ids.
    # NOTE(review): stage 1 decodes out["input_ids"] directly, but this stage
    # indexes [0], which suggests the infilling collator returns a batched
    # tensor -- confirm against DataCollatorForTextInfilling.
    masking = DataCollatorForTextInfilling(tokenizer)
    out = masking(out)
    example["text"] = tokenizer.decode(out["input_ids"][0])
    print(example["text"])


if __name__ == "__main__":
    main()