story_telling_nn.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 31 10:03:36 2025
@author: andrey
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pickle
from Tokenizer import Tokenizer
from TextDataset import TextDataset
from NextWordPredictor import NextWordPredictor
from collate_fn import collate_fn
from LabelSmoothingLoss import LabelSmoothingLoss
from my_trainer import my_trainer
from seeders import seeders
from final_text import final_text
# Flags for optional features
use_adamw = True # Use AdamW instead of Adam
alternate_costs = True # Apply Label Smoothing
train_the_model = True
load_model = True
# Network Settings
batch_size = 10 * 4 * 16
embed_size = 3 * 512
hidden_size = 4 * 256
ff_hidden_size = 4 * 256
num_ff_layers = 4
dropout = 0.00
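# Effective sizes: batch_size = 640, embed_size = 1536, hidden_size = ff_hidden_size = 1024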
# Training settings
lr_ce = 0.0001 * 0.25 * 0.025 * 0.25
lr_ls = 0.0001 * 0.25
smoothing = 0.005
nepochs = 1000
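# Effective values: lr_ce = 1.5625e-7, lr_ls = 2.5e-5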
# Predictor settings
num_words = 30
validate_after_nepochs = 1
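# Note: the generation call at the bottom of the file passes num_words=100 directly, so the value above is currently unused.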
# Paths and constants
path = "/gpfs/work/vlasenko/07/NN/Darwin/"
predicted_steps = 1
load_epoch = 9
checkpoint_path = f"{path}story_telling_final-{predicted_steps}_ep_{load_epoch}.pth"
# Load and preprocess text corpus
with open(f"{path}Darwin_biogr_list_large", "rb") as fp:
    corpus = pickle.load(fp)
tokenizer = Tokenizer()
preprocessed_corpus = [tokenizer.preprocess_text(line) for line in corpus]
tokenizer.fit_on_texts(preprocessed_corpus)
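# word_index is assumed to be 1-based, so +1 leaves index 0 free (typically reserved for padding).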
total_words = len(tokenizer.word_index) + 1
# Create dataset and DataLoader
dataset = TextDataset(corpus, tokenizer)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)
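# Prefer the GPU when one is available; otherwise fall back to the CPU.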
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NextWordPredictor(
    vocab_size=total_words,
    embed_size=embed_size,
    hidden_size=hidden_size,
    ff_hidden_size=ff_hidden_size,
    num_ff_layers=num_ff_layers,
    predict_steps=predicted_steps,
    dropout=dropout
)
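# Two objectives are prepared: plain cross-entropy and a label-smoothed variant
# (presumably alternated by my_trainer when alternate_costs is True).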
criterion_ls = LabelSmoothingLoss(classes=total_words, smoothing=smoothing)
criterion_ce = nn.CrossEntropyLoss()
# --- Optimizer (AdamW or Adam) ---
if use_adamw:
    optimizer_ce = torch.optim.AdamW(model.parameters(), lr=lr_ce)
    optimizer_ls = torch.optim.AdamW(model.parameters(), lr=lr_ls)
else:
    optimizer_ce = torch.optim.Adam(model.parameters(), lr=lr_ce)
    optimizer_ls = torch.optim.Adam(model.parameters(), lr=lr_ls)
# Load model checkpoint if required
if load_model:
    print("Loading checkpoint:", checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)  # FIXME: optimizer state is not restored from the checkpoint yet; to be addressed in a later commit.
    start_epoch = checkpoint.get("epoch", 0) + 1
else:
    start_epoch = 0  # Train from scratch
model.to(device)
if train_the_model:
    my_trainer(
        nepochs,
        path,
        alternate_costs,
        criterion_ce,
        criterion_ls,
        optimizer_ce,
        optimizer_ls,
        model,
        train_loader,
        tokenizer,
        device=device,
        predicted_steps=predicted_steps,
        validate_after_nepochs=validate_after_nepochs,
        seeders=seeders,
        start_epoch=start_epoch
    )
else:
    assert load_model, "To run generation only, set load_model to True and point checkpoint_path at a saved checkpoint."
    for index, seeder in enumerate(seeders, start=1):
        text = final_text(
            seeder,
            model,
            tokenizer,
            num_words=100,
            device=device
        )
        print(f"{index}: {text[0]}")