
Commit 806d144

linter
1 parent c49ff1a commit 806d144


42 files changed: +572 -360 lines

black.toml

+14
@@ -0,0 +1,14 @@
+[tool.black]
+line-length = 120
+target-version = ['py311']
+extend-exclude = '''
+(
+    .*/.*\.pyi
+    | .*/.*_grpc\.py
+    | .*/.*_pb2\.py
+    | .*/generated/
+    | .*/build/
+    | .*/libbuild/
+    | .*/clang-tidy-build/
+)
+'''
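The new black.toml pins the formatter settings applied across this commit: a 120-character line length, a Python 3.11 target, and an extend-exclude regex that skips type stubs, generated gRPC/protobuf modules, and build directories. As a rough illustration only (this snippet is not part of the commit), the line-length and target-version settings map onto black's Python API roughly as follows; note that the exclude patterns only affect command-line file discovery, not direct API calls:

# Illustrative sketch, not part of this commit: mirrors the line-length and
# target-version settings from black.toml using black's Python API.
import black

mode = black.Mode(
    line_length=120,                              # line-length = 120
    target_versions={black.TargetVersion.PY311},  # target-version = ['py311']
)

# A multi-line list comprehension without a trailing comma is collapsed onto
# one line when it fits within the 120-character limit, as in the
# fmengine/cli/main.py change below.
src = (
    "config_files = [\n"
    "    os.path.join(config, f) for f in os.listdir(config) if f.endswith('.yaml')\n"
    "]\n"
)
print(black.format_str(src, mode=mode))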

fmengine/cli/main.py

+1-3
@@ -18,9 +18,7 @@ def train(config: str = typer.Option(..., help="Path to the config file")):
         raise ValueError(f"Config file not found: {config}")
     # if it is a directory, search for yaml files
     if os.path.isdir(config):
-        config_files = [
-            os.path.join(config, f) for f in os.listdir(config) if f.endswith(".yaml")
-        ]
+        config_files = [os.path.join(config, f) for f in os.listdir(config) if f.endswith(".yaml")]
         typer.echo(f"config files found: {config_files}")
         configs = [OmegaConf.load(f) for f in config_files]
         config = OmegaConf.merge(*configs)

fmengine/cli/trainer.py

+7-10
@@ -1,12 +1,13 @@
+import contextlib
 import os
 import time
 
 import humanize
 import torch
-import contextlib
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.fx import GraphModule
 
+from fmengine.core.checkpoint import CheckpointManager, TrainState
 from fmengine.core.configs.train_config import TrainJobConfig
 from fmengine.core.nn import build_lr_scheduler, build_optimizer
 from fmengine.core.nn.loss import cross_entropy_loss
@@ -15,9 +16,8 @@
 from fmengine.models.builder import build_model
 from fmengine.models.llama.modeling_llama import parallelize_llama
 from fmengine.models.utils import get_num_params
-from fmengine.utilities import (GarbageCollection, build_gpu_memory_monitor,
-                                get_peak_flops, logger)
-from fmengine.core.checkpoint import CheckpointManager, TrainState
+from fmengine.utilities import GarbageCollection, build_gpu_memory_monitor, get_peak_flops, logger
+
 
 def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
     @contextlib.contextmanager
@@ -26,13 +26,12 @@ def context():
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
             if enable_compiled_autograd:
-                stack.enter_context(
-                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
-                )
+                stack.enter_context(torch._dynamo.utils.maybe_enable_compiled_autograd(True))
             yield
 
     return context
 
+
 @record
 def train_entry(job_config: TrainJobConfig):
     gc_handler = GarbageCollection()
@@ -64,8 +63,7 @@ def train_entry(job_config: TrainJobConfig):
     model_param_count = get_num_params(model)
     logger.info(f"Model has {humanize.intword(model_param_count)} parameters")
     # todo(xiaozhe): pipeline parallelism enabled
-    parallelize_llama(model, world_mesh, parallel_dims,
-                      train_config=job_config.training)
+    parallelize_llama(model, world_mesh, parallel_dims, train_config=job_config.training)
     init_device = "cuda"
     model.to_empty(device=init_device)
     model_parts = [model]
@@ -105,5 +103,4 @@ def train_entry(job_config: TrainJobConfig):
     )
     time.sleep(10000)
 
-
    torch.distributed.destroy_process_group()

fmengine/core/checkpoint/__init__.py

+1-1
@@ -1,4 +1,4 @@
 from .checkpoint import CheckpointManager
 from .train_state import TrainState
 
-__all__ = ["CheckpointManager", "TrainState"]
+__all__ = ["CheckpointManager", "TrainState"]

fmengine/core/checkpoint/checkpoint.py

+33-76
@@ -13,11 +13,13 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint.state_dict import (StateDictOptions,
-                                                     get_model_state_dict,
-                                                     get_optimizer_state_dict,
-                                                     set_model_state_dict,
-                                                     set_optimizer_state_dict)
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_model_state_dict,
+    set_optimizer_state_dict,
+)
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import DataLoader
 
@@ -50,9 +52,7 @@ def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
 
     def state_dict(self) -> None:
-        return {
-            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
-        }
+        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         func = functools.partial(
@@ -70,8 +70,7 @@ def __init__(
         optim: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]],
     ) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
-        self.optim = [optim] if isinstance(
-            optim, torch.optim.Optimizer) else optim
+        self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
 
     def state_dict(self) -> None:
         func = functools.partial(
@@ -109,8 +108,7 @@ def checkpoint_mp(recv, send):
             state, checkpoint_id = obj
             dcp.save(state, checkpoint_id=checkpoint_id)
             logger.info(
-                "Finish saving the checkpoint in the background process in "
-                f"{time.monotonic() - begin:.2f} seconds."
+                "Finish saving the checkpoint in the background process in " f"{time.monotonic() - begin:.2f} seconds."
             )
     finally:
         logger.info("Destroying the process group.")
@@ -158,19 +156,11 @@ def __init__(
 
         TODO: This is currently unsolved and needs a fix.
         """
-        assert len(model_parts) == len(
-            optimizers
-        ), "Must pass one optimizer per model part"
-        assert len(model_parts) == len(
-            lr_schedulers
-        ), "Must pass one lr_scheduler per model part"
-
-        assert len(model_parts) == len(
-            optimizers
-        ), "Must pass one optimizer per model part"
-        assert len(model_parts) == len(
-            lr_schedulers
-        ), "Must pass one lr_scheduler per model part"
+        assert len(model_parts) == len(optimizers), "Must pass one optimizer per model part"
+        assert len(model_parts) == len(lr_schedulers), "Must pass one lr_scheduler per model part"
+
+        assert len(model_parts) == len(optimizers), "Must pass one optimizer per model part"
+        assert len(model_parts) == len(lr_schedulers), "Must pass one lr_scheduler per model part"
 
         self.states = states
 
@@ -190,11 +180,7 @@ def __init__(
             self.states[f"lr_scheduler_{idx}"] = lr_scheduler
 
         self.folder = os.path.join(ckpt_config.ckpt_dir)
-        self.interval_type = (
-            IntervalType.SECONDS
-            if ckpt_config.interval_type == "seconds"
-            else IntervalType.STEPS
-        )
+        self.interval_type = IntervalType.SECONDS if ckpt_config.interval_type == "seconds" else IntervalType.STEPS
         self.interval = ckpt_config.interval
         self.begin_time = 0
         self.time_sync_work = None
@@ -231,12 +217,9 @@ def __init__(
             self.staging_id = None
            self.staging_stream = torch.cuda.Stream()
         else:
-            raise ValueError(
-                f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
+            raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
 
-        logger.info(
-            f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
-        )
+        logger.info(f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}")
 
     def __del__(self):
         if self.enable_checkpoint and self.mp and self.mp.is_alive():
@@ -268,16 +251,12 @@ def _save_last_step(self, curr_step: int) -> None:
             self.states.pop("freqs_cis")
 
             if self.export_dtype != torch.float32:
-                self.states = {
-                    k: v.to(self.export_dtype) for k, v in self.states.items()
-                }
+                self.states = {k: v.to(self.export_dtype) for k, v in self.states.items()}
             logger.info(
-                f"Saving a model weights only checkpoint in {self.export_dtype} "
-                f"at last step, step {curr_step}."
+                f"Saving a model weights only checkpoint in {self.export_dtype} " f"at last step, step {curr_step}."
             )
         else:
-            logger.info(
-                f"Saving a full checkpoint at last step, step {curr_step}.")
+            logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
 
         dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
         self.reset()
@@ -287,18 +266,13 @@ def _should_save(self, curr_step: int, force: bool = False) -> bool:
             return False
 
         if not force:
-            if self.interval_type == IntervalType.STEPS and not (
-                curr_step % self.interval == 0
-            ):
+            if self.interval_type == IntervalType.STEPS and not (curr_step % self.interval == 0):
                 return False
             if self.interval_type == IntervalType.SECONDS:
-                time_sync_result = (time.monotonic() -
-                                    self.begin_time) >= self.interval
+                time_sync_result = (time.monotonic() - self.begin_time) >= self.interval
                 self.time_sync_result = torch.tensor(int(time_sync_result))
                 if self.time_sync_work is None:
-                    self.time_sync_work = dist.all_reduce(
-                        self.time_sync_result, group=self.pg, async_op=True
-                    )
+                    self.time_sync_work = dist.all_reduce(self.time_sync_result, group=self.pg, async_op=True)
                     return False
                 elif curr_step % 5 == 4:
                     self.time_sync_work.wait()
@@ -319,31 +293,25 @@ def _should_save(self, curr_step: int, force: bool = False) -> bool:
 
     def _async_wait(self) -> None:
         if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            logger.debug(
-                f"Waiting for the background process to finish, {time.monotonic()=}.:.2f"
-            )
+            logger.debug(f"Waiting for the background process to finish, {time.monotonic()=}.:.2f")
            if not self.mp.is_alive():
-                raise RuntimeError(
-                    "The checkpoint background process is dead.")
+                raise RuntimeError("The checkpoint background process is dead.")
             _ = self.mp_queue_recv.get()
         elif self.async_mode == AsyncMode.ASYNC:
             if self.async_future is not None:
                 self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
         try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict, _create_cpu_state_dict)
+            from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
         except ImportError as e:
             raise ImportError(
                 "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
             ) from e
         state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
         if self.cpu_offload_state_dict is None:
             logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True
-            )
+            self.cpu_offload_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
 
         logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
         with torch.cuda.stream(self.staging_stream):
@@ -374,9 +342,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
         elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             self._async_with_pinned_memory(checkpoint_id)
         elif self.async_mode == AsyncMode.ASYNC:
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
+            self.async_future = dcp.async_save(self.states, checkpoint_id=checkpoint_id, process_group=self.pg)
         else:
             dcp.save(self.states, checkpoint_id=checkpoint_id)
         self.reset()
@@ -388,16 +354,10 @@ def save(self, curr_step: int, force: bool = False) -> None:
         )
 
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
+        if self.enable_checkpoint and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM and self.staging:
             logger.debug(f"Waiting for staging, {time.monotonic()=:.2f}.")
             self.staging_stream.synchronize()
-            logger.debug(
-                f"Sending the state dict to the background process, {time.monotonic()=:.2f}."
-            )
+            logger.debug(f"Sending the state dict to the background process, {time.monotonic()=:.2f}.")
             self.mp_queue_send.put((self.staging_state_dict, self.staging_id))
             self.staging = False
 
@@ -413,8 +373,7 @@ def load(self, step: int = -1) -> bool:
             step_counts = []
             for filename in os.listdir(self.folder):
                 match = re.search(r"step-(\d+)", filename)
-                metadata_probe = os.path.join(
-                    self.folder, filename, ".metadata")
+                metadata_probe = os.path.join(self.folder, filename, ".metadata")
                 if match and os.path.isfile(metadata_probe):
                     step_counts.append(int(match.group(1)))
             if not step_counts:
@@ -429,9 +388,7 @@ def load(self, step: int = -1) -> bool:
             states,
             checkpoint_id=self._create_checkpoint_id(step),
         )
-        logger.info(
-            f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
-        )
+        logger.info(f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds.")
        return True
 
     def _purge_stale_checkpoints(self):

fmengine/core/checkpoint/train_state.py

+10-13
@@ -13,11 +13,13 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint.state_dict import (StateDictOptions,
-                                                     get_model_state_dict,
-                                                     get_optimizer_state_dict,
-                                                     set_model_state_dict,
-                                                     set_optimizer_state_dict)
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_optimizer_state_dict,
+    set_model_state_dict,
+    set_optimizer_state_dict,
+)
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import DataLoader
 
@@ -48,13 +50,8 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict) -> None:
         self.step = state_dict["step"].item()
         state_dict["global_avg_losses"].seek(0)
-        self.global_avg_losses = torch.load(
-            state_dict["global_avg_losses"], weights_only=False
-        )
+        self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False)
         state_dict["global_max_losses"].seek(0)
-        self.global_max_losses = torch.load(
-            state_dict["global_max_losses"], weights_only=False
-        )
+        self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False)
         state_dict["log_steps"].seek(0)
-        self.log_steps = torch.load(
-            state_dict["log_steps"], weights_only=False)
+        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)

fmengine/core/configs/__init__.py

+2-3
@@ -1,5 +1,4 @@
-from .train_config import (CheckpointConfig, FP8Config, OptimizerConfig,
-                           TokenizerConfig, TrainingConfig, TrainJobConfig)
+from .train_config import CheckpointConfig, FP8Config, OptimizerConfig, TokenizerConfig, TrainingConfig, TrainJobConfig
 from .utils import TORCH_DTYPE_MAP, dict_to_config
 
 __all__ = [
@@ -10,5 +9,5 @@
     "FP8Config",
     "OptimizerConfig",
     "TokenizerConfig",
-    "TrainingConfig"
+    "TrainingConfig",
 ]
