@@ -53,83 +53,84 @@ def fit(
     best_loss = np.inf
     early_stopping_counter = 0
     epoch_loss = {}
-    with trange(
+    epoch_pbar = trange(
         start_epoch, n_epochs, desc="Epochs", ncols=100, disable=disable_pbar, leave=True
-    ) as epoch_pbar:
-        for epoch in range(start_epoch, n_epochs):
-            epoch_start_time = time.time()
-            callbacks.on_epoch_begin(epoch=epoch + 1)
-
-            epoch_loss.update({"train_loss": 0.0})
-            train_batch_metrics = Metrics.empty()
-
-            for batch_idx in range(train_steps_per_epoch):
-                callbacks.on_train_batch_begin(batch=batch_idx)
-
-                inputs, labels = next(batch_train_ds)
-                train_batch_metrics, batch_loss, state = train_step(
-                    state, inputs, labels, train_batch_metrics
-                )
-
-                epoch_loss["train_loss"] += batch_loss
-                callbacks.on_train_batch_end(batch=batch_idx)
-
-            epoch_loss["train_loss"] /= train_steps_per_epoch
-            epoch_loss["train_loss"] = float(epoch_loss["train_loss"])
-
-            epoch_metrics = {
-                f"train_{key}": float(val)
-                for key, val in train_batch_metrics.compute().items()
-            }
-
-            if val_ds is not None:
-                epoch_loss.update({"val_loss": 0.0})
-                val_batch_metrics = Metrics.empty()
-                for batch_idx in range(val_steps_per_epoch):
-                    inputs, labels = next(batch_val_ds)
-
-                    val_batch_metrics, batch_loss = val_step(
-                        state.params, inputs, labels, val_batch_metrics
-                    )
-                    epoch_loss["val_loss"] += batch_loss
-
-                epoch_loss["val_loss"] /= val_steps_per_epoch
-                epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
+    )
+    for epoch in range(start_epoch, n_epochs):
+        epoch_start_time = time.time()
+        callbacks.on_epoch_begin(epoch=epoch + 1)
 
-                epoch_metrics.update(
-                    {
-                        f"val_{key}": float(val)
-                        for key, val in val_batch_metrics.compute().items()
-                    }
-                )
+        epoch_loss.update({"train_loss": 0.0})
+        train_batch_metrics = Metrics.empty()
 
-            epoch_metrics.update({**epoch_loss})
+        for batch_idx in range(train_steps_per_epoch):
+            callbacks.on_train_batch_begin(batch=batch_idx)
 
-            epoch_end_time = time.time()
-            epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})
+            inputs, labels = next(batch_train_ds)
+            train_batch_metrics, batch_loss, state = train_step(
+                state, inputs, labels, train_batch_metrics
+            )
 
-            ckpt = {"model": state, "epoch": epoch}
-            if epoch % ckpt_interval == 0:
-                ckpt_manager.save_checkpoint(ckpt, epoch, latest_dir)
+            epoch_loss["train_loss"] += batch_loss
+            callbacks.on_train_batch_end(batch=batch_idx)
 
-            if epoch_metrics["val_loss"] < best_loss:
-                best_loss = epoch_metrics["val_loss"]
-                ckpt_manager.save_checkpoint(ckpt, epoch, best_dir)
-                early_stopping_counter = 0
-            else:
-                early_stopping_counter += 1
+        epoch_loss["train_loss"] /= train_steps_per_epoch
+        epoch_loss["train_loss"] = float(epoch_loss["train_loss"])
 
-            callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
+        epoch_metrics = {
+            f"train_{key}": float(val)
+            for key, val in train_batch_metrics.compute().items()
+        }
 
-            epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
-            epoch_pbar.update()
+        if val_ds is not None:
+            epoch_loss.update({"val_loss": 0.0})
+            val_batch_metrics = Metrics.empty()
+            for batch_idx in range(val_steps_per_epoch):
+                inputs, labels = next(batch_val_ds)
 
-            if patience is not None and early_stopping_counter >= patience:
-                log.info(
-                    "Early stopping patience exceeded. Stopping training after"
-                    f" {epoch} epochs."
+                val_batch_metrics, batch_loss = val_step(
+                    state.params, inputs, labels, val_batch_metrics
                 )
-                break
+                epoch_loss["val_loss"] += batch_loss
+
+            epoch_loss["val_loss"] /= val_steps_per_epoch
+            epoch_loss["val_loss"] = float(epoch_loss["val_loss"])
+
+            epoch_metrics.update(
+                {
+                    f"val_{key}": float(val)
+                    for key, val in val_batch_metrics.compute().items()
+                }
+            )
+
+        epoch_metrics.update({**epoch_loss})
+
+        epoch_end_time = time.time()
+        epoch_metrics.update({"epoch_time": epoch_end_time - epoch_start_time})
+
+        ckpt = {"model": state, "epoch": epoch}
+        if epoch % ckpt_interval == 0:
+            ckpt_manager.save_checkpoint(ckpt, epoch, latest_dir)
+
+        if epoch_metrics["val_loss"] < best_loss:
+            best_loss = epoch_metrics["val_loss"]
+            ckpt_manager.save_checkpoint(ckpt, epoch, best_dir)
+            early_stopping_counter = 0
+        else:
+            early_stopping_counter += 1
+
+        callbacks.on_epoch_end(epoch=epoch, logs=epoch_metrics)
+
+        epoch_pbar.set_postfix(val_loss=epoch_metrics["val_loss"])
+        epoch_pbar.update()
+
+        if patience is not None and early_stopping_counter >= patience:
+            log.info(
+                "Early stopping patience exceeded. Stopping training after"
+                f" {epoch} epochs."
+            )
+            break
+    epoch_pbar.close()
     callbacks.on_train_end()
 
 
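The diff drops the `with trange(...)` context manager in favour of an explicitly created progress bar, which removes one indentation level from the epoch loop and requires the `epoch_pbar.close()` call added after the loop. Below is a minimal, self-contained sketch of that tqdm pattern; the function name, the early-stopping counter, and the placeholder metric are illustrative only and are not taken from the repository.

from typing import Optional

from tqdm import trange

def run_epochs(n_epochs: int, patience: Optional[int] = None) -> None:
    # Create the bar explicitly instead of `with trange(...) as pbar:` so the
    # loop body sits one level shallower; the bar must then be closed manually.
    pbar = trange(n_epochs, desc="Epochs", ncols=100, leave=True)
    stale_epochs = 0  # hypothetical early-stopping counter for this sketch
    for epoch in range(n_epochs):
        val_loss = 1.0 / (epoch + 1)  # placeholder metric for illustration
        pbar.set_postfix(val_loss=val_loss)
        pbar.update()
        if patience is not None and stale_epochs >= patience:
            break  # leaving the loop early still reaches pbar.close() below
    pbar.close()  # mirrors the epoch_pbar.close() added at the end of fit()

run_epochs(5)

Functionally the two forms are equivalent as long as `close()` is reached; the explicit form mainly trades the context manager's automatic cleanup for a flatter loop body.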