12
12
print (f"Using device: { device } " )
13
13
14
14
# D&D Trainer Class
15
- class DnDTrainer :
15
+ class RollToTrain :
16
16
"""Main trainer class"""
17
17
def __init__ (self , model , tokenizer , optimizer , lr_scheduler , intelligence = 10 , dc = 15 ,
18
- accumulation_steps = 64 , mode = "per_mini_batch" ):
18
+ accumulation_steps = 64 , mode = "per_mini_batch" , num_epochs = 3 ):
19
19
self .model = model .to (device )
20
20
self .tokenizer = tokenizer
21
21
self .optimizer = optimizer
@@ -33,8 +33,8 @@ def __init__(self, model, tokenizer, optimizer, lr_scheduler, intelligence=10, d
33
33
self ._grad_accum_counter = 0
34
34
self ._accumulated_loss = 0
35
35
self ._mode = mode
36
- self .step = 0
37
- self .steps = 0
36
+ self .epoch = 0
37
+ self .epochs = num_epochs
38
38
39
39
def roll_d20 (self ):
40
40
"""Roll a D20 dice on the GPU."""
@@ -71,7 +71,6 @@ def weight_update(self, loss):
71
71
72
72
self ._accumulated_loss += loss .item ()
73
73
74
- # Perform optimization step after accumulation
75
74
if self ._grad_accum_counter >= self .accumulation_steps :
76
75
print ("Performing optimizer step after gradient accumulation" )
77
76
self ._loss_history .append (self ._accumulated_loss / self .accumulation_steps )
@@ -88,17 +87,16 @@ def weight_update(self, loss):
88
87
self ._accumulated_loss = 0
89
88
self ._grad_accum_counter = 0
90
89
91
- def train (self , train_dataloader , eval_dataloader , steps = 3 , eval_steps = 100 ):
90
+ def train (self , train_dataloader , eval_dataloader ):
92
91
"""Train the model for a specified number of steps."""
93
- self .steps = steps
94
- self .step = 0
92
+ self .epoch = 0
95
93
96
- while self .step < self .steps :
94
+ while self .epoch < self .epochs :
97
95
for batch_idx , batch in enumerate (train_dataloader ):
98
96
if self .model .eval :
99
97
self .model .train ()
100
98
101
- print (f"Step { self .step + 1 } , Batch { batch_idx + 1 } " )
99
+ print (f"Step { self .epoch + 1 } , Batch { batch_idx + 1 } " )
102
100
inputs = self .tokenizer (batch ["text" ], padding = True , truncation = True ,
103
101
return_tensors = "pt" , max_length = 512 ).to (device )
104
102
labels = batch ["label" ].to (device )
@@ -109,13 +107,13 @@ def train(self, train_dataloader, eval_dataloader, steps=3, eval_steps=100):
109
107
110
108
self .weight_update (loss )
111
109
112
- if self .step >= self .steps :
110
+ if self .epoch >= self .epochs :
113
111
break
114
112
115
- if (self .step + 1 ) % eval_steps == 0 :
116
- self .evaluate (eval_dataloader )
117
-
113
+ self .evaluate (eval_dataloader )
118
114
self .lr_scheduler .step ()
115
+ self .epoch += 1
116
+ self .plot_loss (len (train_dataloader ))
119
117
120
118
def evaluate (self , eval_dataloader ):
121
119
"""Evaluate the model on the validation set."""
@@ -134,7 +132,7 @@ def evaluate(self, eval_dataloader):
134
132
self ._eval_loss_history .append (avg_loss )
135
133
print (f"Evaluation Loss: { avg_loss :.4f} " )
136
134
137
- def plot_loss (self ):
135
+ def plot_loss (self , steps ):
138
136
"""Plot and save the training and evaluation loss."""
139
137
fig , axes = plt .subplots (3 , 1 , figsize = (10 , 20 ), sharex = True )
140
138
@@ -145,19 +143,22 @@ def plot_loss(self):
145
143
axes [0 ].grid (True , linestyle = '--' , alpha = 0.7 )
146
144
147
145
# Loss After Roll
148
- axes [1 ].plot (self ._modified_loss_history , color = 'green' , linestyle = '-' , marker = 'x' )
146
+ modified_loss_steps = [i for i in range (steps )] if self ._mode == "per_mini_batch" else [i for i in range (0 , steps ,
147
+ self .accumulation_steps )]
148
+ axes [1 ].plot (modified_loss_steps , self ._modified_loss_history , color = 'green' , linestyle = '-' , marker = 'x' )
149
149
axes [1 ].set_title ('Loss After Roll' )
150
150
axes [1 ].set_ylabel ('Loss' )
151
151
axes [1 ].grid (True , linestyle = '--' , alpha = 0.7 )
152
152
153
153
# Evaluation Loss
154
- axes [2 ].plot (self ._eval_loss_history , color = 'red' , linestyle = '-' , marker = 's' )
154
+ eval_steps = [i for i in range (0 , steps * self .epochs , steps )]
155
+ axes [2 ].plot (eval_steps , self ._eval_loss_history , color = 'red' , linestyle = '-' , marker = 's' )
155
156
axes [2 ].set_title ('Evaluation Loss' )
156
157
axes [2 ].set_xlabel ('Training Steps' )
157
158
axes [2 ].set_ylabel ('Loss' )
158
159
axes [2 ].grid (True , linestyle = '--' , alpha = 0.7 )
159
160
160
161
# Save the figure
161
162
plt .tight_layout ()
162
- plt .savefig ("roll_to_train_loss_subplots .png" )
163
- print ("Saved loss plots as 'roll_to_train_loss_subplots .png'" )
163
+ plt .savefig (f" { self . _mode } _roll_to_train_loss_subplots .png" )
164
+ print (f "Saved loss plots as '{ self . _mode } _roll_to_train_loss_subplots .png'" )
0 commit comments