@@ -56,7 +56,6 @@ class TrainConfig(BaseConfig):
56
56
57
57
58
58
class MonitorConfig (BaseConfig ):
59
- enable_monitor : bool = False
60
59
log_flush_interval : int = 10
61
60
base_url : str | None = None
62
61
auth_token : str | None = None
@@ -76,7 +75,7 @@ class Config(BaseConfig):
76
75
data : DataConfig = DataConfig ()
77
76
optim : OptimConfig = OptimConfig ()
78
77
train : TrainConfig
79
- monitor : MonitorConfig = MonitorConfig ()
78
+ monitor : MonitorConfig | None = None
80
79
81
80
82
81
def train (config : Config ):
@@ -167,8 +166,9 @@ def train(config: Config):
167
166
logger_cls = WandbMetricLogger if config .metric_logger_type == "wandb" else DummyMetricLogger
168
167
metric_logger = logger_cls (project = config .project , config = config .model_dump (), resume = False )
169
168
170
- monitor = HttpMonitor (config = config .model_dump (), resume = False )
171
- monitor .set_stage ("init" )
169
+ if config .monitor is not None :
170
+ monitor = HttpMonitor (config = config .model_dump (), resume = False )
171
+ monitor .set_stage ("init" )
172
172
173
173
train_dataloader_iterator = iter (train_dataloader )
174
174
@@ -182,7 +182,7 @@ def train(config: Config):
182
182
# if we don't use diloco we don't print the outer step logs
183
183
logger .info (f"outer_step step: { outer_step } " )
184
184
185
- if world_info .rank == 0 :
185
+ if world_info .rank == 0 and config . monitor is not None :
186
186
monitor .set_stage ("inner_loop" )
187
187
188
188
for inner_step in range (num_inner_steps ):
@@ -245,12 +245,13 @@ def train(config: Config):
245
245
246
246
if world_info .rank == 0 :
247
247
metric_logger .log (metrics )
248
- monitor .log (metrics )
248
+ if config .monitor is not None :
249
+ monitor .log (metrics )
249
250
250
251
logger .info (log )
251
252
252
253
if config .diloco is not None :
253
- if world_info .rank == 0 :
254
+ if world_info .rank == 0 and config . monitor is not None :
254
255
monitor .set_stage ("outer_loop" )
255
256
diloco .step (model )
256
257
@@ -263,8 +264,9 @@ def train(config: Config):
263
264
break
264
265
265
266
if world_info .rank == 0 :
266
- monitor .finish ()
267
267
metric_logger .finish ()
268
+ if config .monitor is not None :
269
+ monitor .finish ()
268
270
269
271
270
272
if __name__ == "__main__" :
0 commit comments