Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RuiFilipeCampos committed Jan 3, 2024
1 parent 0e7f390 commit b13a51e
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion gpt_shakespear/train_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,13 @@ def zero_grad(optimizer: torch.optim.Optimizer) -> Iterator[None]:
# Save the model and log the loss

mlflow.pytorch.log_model(nanoGPT, f"nanogpt_{epoch}")
mlflow.log_metric("loss", loss.item(), epoch)
mlflow.log_metric("loss/train", loss.item(), epoch)
with torch.no_grad():
in_sequence_bw, out_sequence_bwv = generate_batch(val_ids)
pred_logits_btw = pred_logits_bwt.transpose(-1, -2)
val_loss = loss_function(pred_logits_btw, out_sequence_bw)
mlflow.log_metric("loss/val", val_loss.item(), epoch)

mlflow.log_metric("epoch", epoch, epoch)


Expand Down

0 comments on commit b13a51e

Please sign in to comment.