test_language_model.py
import pytest

from flair.data import Dictionary, Sentence
from flair.embeddings import FlairEmbeddings, TokenEmbeddings
from flair.models import LanguageModel
from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus


@pytest.mark.integration()
def test_train_language_model(results_base_path, resources_path):
    # get default dictionary
    dictionary: Dictionary = Dictionary.load("chars")

    # init forward LM with 128 hidden states and 1 layer
    language_model: LanguageModel = LanguageModel(dictionary, is_forward_lm=True, hidden_size=128, nlayers=1)

    # get the example corpus and process at character level in forward direction
    corpus: TextCorpus = TextCorpus(
        resources_path / "corpora/lorem_ipsum",
        dictionary,
        language_model.is_forward_lm,
        character_level=True,
    )

    # train the language model
    trainer: LanguageModelTrainer = LanguageModelTrainer(language_model, corpus, test_mode=True)
    trainer.train(results_base_path, sequence_length=10, mini_batch_size=10, max_epochs=2)

    # use the character LM as embeddings to embed the example sentence 'I love Berlin'
    char_lm_embeddings: TokenEmbeddings = FlairEmbeddings(str(results_base_path / "best-lm.pt"))
    sentence = Sentence("I love Berlin")
    char_lm_embeddings.embed(sentence)

    # generate text with the trained model and check that it is non-empty
    text, likelihood = language_model.generate_text(number_of_characters=100)
    assert text is not None
    assert len(text) >= 100

    # clean up results directory
    del trainer, language_model, corpus, char_lm_embeddings


@pytest.mark.integration()
def test_train_resume_language_model(resources_path, results_base_path, tasks_base_path):
    # get default dictionary
    dictionary: Dictionary = Dictionary.load("chars")

    # init forward LM with 128 hidden states and 1 layer
    language_model: LanguageModel = LanguageModel(dictionary, is_forward_lm=True, hidden_size=128, nlayers=1)

    # get the example corpus and process at character level in forward direction
    corpus: TextCorpus = TextCorpus(
        resources_path / "corpora/lorem_ipsum",
        dictionary,
        language_model.is_forward_lm,
        character_level=True,
    )

    # train the language model and write a checkpoint
    trainer: LanguageModelTrainer = LanguageModelTrainer(language_model, corpus, test_mode=True)
    trainer.train(
        results_base_path,
        sequence_length=10,
        mini_batch_size=10,
        max_epochs=2,
        checkpoint=True,
    )
    del trainer, language_model

    # resume training from the stored checkpoint
    trainer = LanguageModelTrainer.load_checkpoint(results_base_path / "checkpoint.pt", corpus)
    trainer.train(results_base_path, sequence_length=10, mini_batch_size=10, max_epochs=2)

    del trainer


def test_generate_text_with_small_temperatures():
    language_model = FlairEmbeddings("news-forward-fast", has_decoder=True).lm

    text, likelihood = language_model.generate_text(temperature=0.01, number_of_characters=100)
    assert text is not None
    assert len(text) >= 100

    del language_model


def test_compute_perplexity():
    # the forward LM should assign lower perplexity to a grammatical English
    # sentence than to a nonsensical one
    language_model = FlairEmbeddings("news-forward-fast", has_decoder=True).lm

    grammatical = "The company made a profit"
    perplexity_grammatical_sentence = language_model.calculate_perplexity(grammatical)

    ungrammatical = "Nook negh qapla!"
    perplexity_ungrammatical_sentence = language_model.calculate_perplexity(ungrammatical)

    print(f'"{grammatical}" - perplexity is {perplexity_grammatical_sentence}')
    print(f'"{ungrammatical}" - perplexity is {perplexity_ungrammatical_sentence}')

    assert perplexity_grammatical_sentence < perplexity_ungrammatical_sentence

    # the same should hold for the backward LM
    language_model = FlairEmbeddings("news-backward-fast", has_decoder=True).lm

    grammatical = "The company made a profit"
    perplexity_grammatical_sentence = language_model.calculate_perplexity(grammatical)

    ungrammatical = "Nook negh qapla!"
    perplexity_ungrammatical_sentence = language_model.calculate_perplexity(ungrammatical)

    print(f'"{grammatical}" - perplexity is {perplexity_grammatical_sentence}')
    print(f'"{ungrammatical}" - perplexity is {perplexity_ungrammatical_sentence}')

    assert perplexity_grammatical_sentence < perplexity_ungrammatical_sentence

    del language_model
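

# Usage note (a sketch under assumptions, not taken from this file): assuming a
# standard pytest setup in which the "integration" marker is registered and the
# results_base_path / resources_path / tasks_base_path fixtures are provided by
# the repository's conftest.py, the integration tests above can be selected by
# marker, e.g.:
#
#     pytest -m integration test_language_model.py
#
# The two non-integration tests run as part of a plain `pytest` invocation, but
# note that they download the pre-trained "news-forward-fast" / "news-backward-fast"
# models on first use.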