
Commit 0232232

Merge pull request #7 from sbartlett97/1-fix-train-loop

Fix training loop and other tweaks

2 parents b388211 + 5bef388, commit 0232232

File tree: 2 files changed (+26, -24 lines)


main.py

Lines changed: 6 additions & 5 deletions

@@ -4,7 +4,8 @@
 from torch.utils.data import DataLoader
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from datasets import load_dataset
-from roll_to_train import DnDTrainer
+from roll_to_train import RollToTrain
+
 
 def main(intelligence=15, dc=12.0, dataset=None):
     model_name = "bert-base-uncased"
@@ -17,11 +18,11 @@ def main(intelligence=15, dc=12.0, dataset=None):
     optimizer = AdamW(model.parameters(), lr=5e-5)
     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
 
-    trainer = DnDTrainer(model, tokenizer, optimizer, scheduler, intelligence, float(dc))
-    trainer.train(dataloader, val_dataloader, steps=len(dataloader), eval_steps=100)
-    trainer = DnDTrainer(model, tokenizer, optimizer, scheduler, intelligence, float(dc),
+    trainer = RollToTrain(model, tokenizer, optimizer, scheduler, intelligence, float(dc))
+    trainer.train(dataloader, val_dataloader)
+    trainer = RollToTrain(model, tokenizer, optimizer, scheduler, intelligence, float(dc),
                           mode="per_accumulation_step")
-    trainer.train(dataloader, val_dataloader, steps=len(dataloader), eval_steps=100)
+    trainer.train(dataloader, val_dataloader)
 
 
 if __name__=="__main__":
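The data preparation between the model setup and the optimizer is in unchanged lines not shown by this diff. For orientation, a minimal sketch of how the updated entry point might be driven end to end; the dataset (IMDb), subset sizes, batch size, and num_labels are illustrative assumptions, not the repo's actual values:

    import torch
    from torch.optim import AdamW
    from torch.utils.data import DataLoader
    from datasets import load_dataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    from roll_to_train import RollToTrain

    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # num_labels=2 is an assumption; the diff does not show the model call.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Dataset choice, subset sizes, and batch size are illustrative assumptions.
    dataset = load_dataset("imdb")
    train_split = dataset["train"].shuffle(seed=42).select(range(2048))
    val_split = dataset["test"].shuffle(seed=42).select(range(256))
    dataloader = DataLoader(train_split, batch_size=8, shuffle=True)
    val_dataloader = DataLoader(val_split, batch_size=8)

    optimizer = AdamW(model.parameters(), lr=5e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    # train() no longer takes steps/eval_steps; epoch count now lives on the trainer.
    trainer = RollToTrain(model, tokenizer, optimizer, scheduler, intelligence=15, dc=12.0)
    trainer.train(dataloader, val_dataloader)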

roll_to_train.py

Lines changed: 20 additions & 19 deletions

@@ -12,10 +12,10 @@
 print(f"Using device: {device}")
 
 # D&D Trainer Class
-class DnDTrainer:
+class RollToTrain:
     """Main trainer class"""
     def __init__(self, model, tokenizer, optimizer, lr_scheduler, intelligence=10, dc=15,
-                 accumulation_steps=64, mode="per_mini_batch"):
+                 accumulation_steps=64, mode="per_mini_batch", num_epochs=3):
         self.model = model.to(device)
         self.tokenizer = tokenizer
         self.optimizer = optimizer
@@ -33,8 +33,8 @@ def __init__(self, model, tokenizer, optimizer, lr_scheduler, intelligence=10, d
         self._grad_accum_counter = 0
         self._accumulated_loss = 0
         self._mode = mode
-        self.step = 0
-        self.steps = 0
+        self.epoch = 0
+        self.epochs = num_epochs
 
     def roll_d20(self):
         """Roll a D20 dice on the GPU."""
@@ -71,7 +71,6 @@ def weight_update(self, loss):
 
         self._accumulated_loss += loss.item()
 
-        # Perform optimization step after accumulation
         if self._grad_accum_counter >= self.accumulation_steps:
             print("Performing optimizer step after gradient accumulation")
             self._loss_history.append(self._accumulated_loss / self.accumulation_steps)
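This branch fires once every accumulation_steps mini-batches and logs the mean accumulated loss. For readers new to the pattern, a self-contained generic sketch of loss-scaled gradient accumulation (a toy model, not this repo's weight_update):

    import torch
    from torch import nn

    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accumulation_steps = 4

    optimizer.zero_grad()
    for step in range(16):
        x, y = torch.randn(8, 4), torch.randn(8, 1)      # stand-in mini-batch
        loss = nn.functional.mse_loss(model(x), y)
        (loss / accumulation_steps).backward()           # scale so gradients average
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()                             # one update per 4 mini-batches
            optimizer.zero_grad()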
@@ -88,17 +87,16 @@ def weight_update(self, loss):
             self._accumulated_loss = 0
             self._grad_accum_counter = 0
 
-    def train(self, train_dataloader, eval_dataloader, steps=3, eval_steps=100):
+    def train(self, train_dataloader, eval_dataloader):
         """Train the model for a specified number of steps."""
-        self.steps = steps
-        self.step = 0
+        self.epoch = 0
 
-        while self.step < self.steps:
+        while self.epoch < self.epochs:
             for batch_idx, batch in enumerate(train_dataloader):
                 if self.model.eval:
                     self.model.train()
 
-                print(f"Step {self.step + 1}, Batch {batch_idx + 1}")
+                print(f"Step {self.epoch + 1}, Batch {batch_idx + 1}")
                 inputs = self.tokenizer(batch["text"], padding=True, truncation=True,
                                         return_tensors="pt", max_length=512).to(device)
                 labels = batch["label"].to(device)
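The forward pass that produces `loss` sits in unchanged lines elided between these hunks. The usual Hugging Face pattern for a sequence-classification model (a generic sketch of that step, not necessarily the repo's verbatim lines) is:

    # Assumed pattern for the elided forward/loss step:
    outputs = self.model(**inputs, labels=labels)  # HF computes cross-entropy when labels are supplied
    loss = outputs.loss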
@@ -109,13 +107,13 @@ def train(self, train_dataloader, eval_dataloader, steps=3, eval_steps=100):
 
                 self.weight_update(loss)
 
-                if self.step >= self.steps:
+                if self.epoch >= self.epochs:
                     break
 
-                if (self.step + 1) % eval_steps == 0:
-                    self.evaluate(eval_dataloader)
-
+            self.evaluate(eval_dataloader)
             self.lr_scheduler.step()
+            self.epoch += 1
+        self.plot_loss(len(train_dataloader))
 
     def evaluate(self, eval_dataloader):
         """Evaluate the model on the validation set."""
@@ -134,7 +132,7 @@ def evaluate(self, eval_dataloader):
         self._eval_loss_history.append(avg_loss)
         print(f"Evaluation Loss: {avg_loss:.4f}")
 
-    def plot_loss(self):
+    def plot_loss(self, steps):
         """Plot and save the training and evaluation loss."""
         fig, axes = plt.subplots(3, 1, figsize=(10, 20), sharex=True)
 
@@ -145,19 +143,22 @@ def plot_loss(self):
         axes[0].grid(True, linestyle='--', alpha=0.7)
 
         # Loss After Roll
-        axes[1].plot(self._modified_loss_history, color='green', linestyle='-', marker='x')
+        modified_loss_steps = [i for i in range(steps)] if self._mode == "per_mini_batch" else [i for i in range(0, steps,
+                                                                                                self.accumulation_steps)]
+        axes[1].plot(modified_loss_steps, self._modified_loss_history, color='green', linestyle='-', marker='x')
         axes[1].set_title('Loss After Roll')
         axes[1].set_ylabel('Loss')
         axes[1].grid(True, linestyle='--', alpha=0.7)
 
         # Evaluation Loss
-        axes[2].plot(self._eval_loss_history, color='red', linestyle='-', marker='s')
+        eval_steps = [i for i in range(0, steps*self.epochs, steps)]
+        axes[2].plot(eval_steps, self._eval_loss_history, color='red', linestyle='-', marker='s')
         axes[2].set_title('Evaluation Loss')
         axes[2].set_xlabel('Training Steps')
         axes[2].set_ylabel('Loss')
         axes[2].grid(True, linestyle='--', alpha=0.7)
 
         # Save the figure
         plt.tight_layout()
-        plt.savefig("roll_to_train_loss_subplots.png")
-        print("Saved loss plots as 'roll_to_train_loss_subplots.png'")
+        plt.savefig(f"{self._mode}_roll_to_train_loss_subplots.png")
+        print(f"Saved loss plots as '{self._mode}_roll_to_train_loss_subplots.png'")
