v0.11.0
Composer v0.11.0
Composer v0.11.0 is released! Install via pip:

```bash
pip install --upgrade mosaicml==0.11.0
```
New Features
- FSDP Beta Support
Composer now supports PyTorch FSDP! PyTorch FSDP is a strategy for distributed training that, like PyTorch DDP, distributes work using data parallelism. On top of this, FSDP shards model parameters, gradients, and optimizer state to dramatically reduce device memory requirements, enabling users to easily scale and train large models.
Here's how easy it is to use FSDP with Composer:
```python
import torch.nn as nn
from composer import Trainer
from composer.models import ComposerModel

class Block(nn.Module):
    ...

# Your custom model
class Model(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(...) for _ in range(n_layers)
        ])
        self.head = nn.Linear(...)

    def forward(self, inputs):
        ...

    # FSDP Wrap Function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, Block)

    # Activation Checkpointing Function
    def activation_checkpointing_fn(self, module):
        return isinstance(module, Block)

# ComposerModel wrapper, used by the Trainer
# to compute loss, metrics, etc.
class MyComposerModel(ComposerModel):
    def __init__(self, n_layers):
        super().__init__()
        self.model = Model(n_layers)
        ...

    def forward(self, batch):
        ...

    def eval_forward(self, batch, outputs=None):
        ...

    def loss(self, outputs, batch):
        ...

# Pass your ComposerModel and fsdp_config into the Trainer
composer_model = MyComposerModel(n_layers=3)

fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'min_params': 1e8,
    'cpu_offload': False,  # Not supported yet
    'mixed_precision': 'DEFAULT',
    'backward_prefetch': 'BACKWARD_POST',
    'activation_checkpointing': False,
    'activation_cpu_offload': False,
    'verbose': True
}

trainer = Trainer(
    model=composer_model,
    fsdp_config=fsdp_config,
    ...
)

trainer.fit()
```
For more information, please see our FSDP docs.
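As a rough illustration of how the same config can be tuned, the sketch below varies two of the keys shown above: `sharding_strategy` (e.g. `'SHARD_GRAD_OP'`, which shards gradients and optimizer state while keeping full parameters on each device) and `activation_checkpointing`, which trades compute for memory by recomputing `Block` activations during the backward pass. The values are illustrative assumptions, not recommended defaults.

```python
# Hypothetical variation on the config above (illustrative values only):
# shard only gradients and optimizer state, and recompute Block activations
# during the backward pass instead of storing them.
fsdp_config = {
    'sharding_strategy': 'SHARD_GRAD_OP',
    'min_params': 1e8,
    'mixed_precision': 'DEFAULT',
    'backward_prefetch': 'BACKWARD_POST',
    'activation_checkpointing': True,
    'activation_cpu_offload': False,
    'verbose': True,
}
# Pass fsdp_config to the Trainer exactly as in the example above.
```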
- Streaming v0.1
We've spun off Streaming datasets into its own repository! Streaming datasets is a high-performance drop-in replacement for Torch `IterableDataset`, enabling users to stream training data from cloud-based object stores. Streaming ships with built-in support for popular open source datasets (ADE20K, C4, COCO, Enwiki, ImageNet, etc.).

To get started, install the Streaming PyPI package:
```bash
pip install mosaicml-streaming
```
You can use the streaming Dataset class with the PyTorch native DataLoader class as follows:
```python
import torch
from streaming import Dataset

dataloader = torch.utils.data.DataLoader(dataset=Dataset(remote='s3://...'))
```
For more information, please check out the Streaming docs.
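Because the streaming `Dataset` plugs into a standard PyTorch `DataLoader`, it can be handed straight to the Composer `Trainer` like any other dataloader. The snippet below is a minimal sketch of that wiring; the `remote` URI, batch size, `max_duration`, and the `my_composer_model` placeholder are assumptions for illustration, not part of this release.

```python
import torch
from composer import Trainer
from streaming import Dataset

# Stream shards from a remote object store (placeholder URI).
train_dataloader = torch.utils.data.DataLoader(
    dataset=Dataset(remote='s3://...'),
    batch_size=32,
)

# my_composer_model is a hypothetical ComposerModel, e.g. MyComposerModel above.
trainer = Trainer(
    model=my_composer_model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
)
trainer.fit()
```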
- Simplified Checkpointing Interface
With this release we've greatly simplified configuration of loading and saving checkpoints in Composer.
To save checkpoints to S3, all you need to do is:

- Set `save_folder` to the full URI of your save directory destination (e.g. `'s3://my-bucket/{run_name}/checkpoints'`)
- Optionally, set `save_filename` to the pattern you want for your checkpoint file names
```python
from composer.trainer import Trainer

# Checkpoint saving to S3.
trainer = Trainer(
    model=model,
    save_folder="s3://my-bucket/{run_name}/checkpoints",
    run_name='my-run',
    save_interval="1ep",
    save_filename="ep{epoch}.pt",
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
    ...
)

trainer.fit()
```
Likewise, to load checkpoints from S3, all you have to do is:

- Set `load_path` to the full URI of your desired checkpoint file (e.g. `'s3://my-bucket/my-run/checkpoints/ep13.pt'`)
```python
from composer.trainer import Trainer

# Checkpoint loading from S3.
new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="s3://my-bucket/my-run/checkpoints/ep13.pt",
)

new_trainer.fit()
```
For more information, please see our Checkpointing guide.
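A common follow-on to remote checkpointing is automatically resuming a run after a failure or preemption (see "Auto resumption (#1615)" under Bug Fixes). The sketch below assumes the Trainer's `autoresume` flag together with a remote `save_folder` and a stable `run_name`; treat it as an illustration rather than a full recipe.

```python
from composer.trainer import Trainer

# If a checkpoint for this run already exists under save_folder, training
# resumes from the latest one; otherwise it starts from scratch.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    run_name='my-run',
    save_folder="s3://my-bucket/{run_name}/checkpoints",
    save_interval="1ep",
    autoresume=True,
)
trainer.fit()
```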
- Improved Distributed Experience
We've made it easier to write your own custom distributed entry points by exposing our distributed API. You can now leverage all of our helpful distributed functions and contexts.
For example, let's say we need to download a dataset in a distributed training application. To avoid race conditions where different ranks try to write the dataset to the same place, we need to ensure that only rank 0 downloads the dataset first:
```python
import datetime

from composer.trainer.devices import DeviceGPU
from composer.utils import dist

# Initialize the distributed module
dist.initialize_dist(DeviceGPU(), datetime.timedelta(seconds=30))

if dist.get_local_rank() == 0:
    # Download dataset on rank zero
    dataset = download_my_dataset()

dist.barrier()  # All ranks wait until the dataset is downloaded

# Create and train your model!
```
For more information, please check out our Distributed API docs.
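The same module exposes other helpers for rank-aware logic, such as `dist.get_global_rank()`, `dist.get_world_size()`, and `dist.all_gather_object()`. The snippet below is a small sketch of gathering a per-rank value onto every rank; the sample-count computation is a made-up placeholder for illustration.

```python
from composer.utils import dist

# Each rank computes some local value (placeholder computation).
local_sample_count = 1000 + dist.get_global_rank()

# Gather the Python objects from every rank; every rank receives the full list.
all_counts = dist.all_gather_object(local_sample_count)

if dist.get_global_rank() == 0:
    print(f'world size: {dist.get_world_size()}, per-rank counts: {all_counts}')
```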
Bug Fixes
- fix loss and eval_forward for HF models (#1597)
- add more robust casting to int for fsdp min_params (#1608)
- Deepspeed Docs Typo (#1605)
- Fix mmdet typo (#1618)
- Blurpool idempotent (#1625)
- When model is not on `meta` device, initialization should occur on compute device not CPU (#1623)
- Auto resumption (#1615)
- Adjust speed monitor (#1645)
- Hot fix console logging (#1643)
- Lazy Logging + pretty print dict for hparams (#1653)
- Fix many failing notebook tests (#1646)
What's Changed
- Bump coverage[toml] from 6.4.4 to 6.5.0 by @dependabot in #1583
- Bump furo from 2022.9.15 to 2022.9.29 by @dependabot in #1584
- Add English Wikipedia 2020-01-01 dataset by @knighton in #1572
- Add pull request template by @dakinggg in #1588
- Bump ipykernel from 6.15.3 to 6.16.0 by @dependabot in #1587
- Update importlib-metadata requirement from <5,>=4.11.0 to >=5.0,<6 by @dependabot in #1585
- Bump sphinx-argparse from 0.3.1 to 0.3.2 by @dependabot in #1586
- Add step explicitly to ImageVisualizer logging calls by @dakinggg in #1591
- Image viz test by @dakinggg in #1592
- Remove unused fixture by @mvpatel2000 in #1594
- Fixes RandAugment API by @mvpatel2000 in #1596
- fix loss and eval_forward for HF models by @dskhudia in #1597
- Remove tensorflow-io from setup.py by @eracah in #1577
- Fixes enwiki for the newly processed wiki dataset by @dskhudia in #1600
- Change install to all by @mvpatel2000 in #1599
- Remove log level and should_log_artifact by @dakinggg in #1603
- Add more robust casting to int for fsdp min_params by @dblalock in #1608
- Deepspeed Docs Typo by @mvpatel2000 in #1605
- Object store logger refactor by @dakinggg in #1601
- Bump gitpython from 3.1.27 to 3.1.28 by @dependabot in #1609
- Bump tabulate from 0.8.10 to 0.9.0 by @dependabot in #1610
- Log the number of GPUs and nodes Composer running on. by @eracah in #1604
- Update MLPerfCallback for v2.1 by @hanlint in #1607
- Remove object store cls by @dakinggg in #1606
- Add LAMB Optimizer by @hanlint in #1613
- Mmdet adapter by @A-Jacobson in #1545
- Fix mmdet typo by @Landanjs in #1618
- update torchmetrics requirement by @hanlint in #1620
- Add distributed sampler error by @mvpatel2000 in #1598
- Landan/deeplabv3 ade20k example by @Landanjs in #1593
- Upgrade CodeQL Action to version 2 by @karan6181 in #1628
- Blurpool idempotent by @mvpatel2000 in #1625
- Defaulting streaming dataset version to 2 by @karan6181 in #1616
- Abhi/fsdp bugfix 0 11 by @abhi-mosaic in #1623
- Remove warning when `master_port` is auto selected by @abhi-mosaic in #1629
- Remove unused import by @dakinggg in #1630
- Usability improvements to `initialize_dist()` by @growlix in #1619
- Remove Graph in Auto Grad Accum by @mvpatel2000 in #1631
- Auto resumption by @dakinggg in #1615
- add stop method by @hanlint in #1627
- S3 Checkpoint Saving By URI by @eracah in #1614
- S3 Checkpoint loading from URI by @eracah in #1624
- Add mvpatel2000 as codeowner for algos by @mvpatel2000 in #1640
- Adjust speed monitor by @mvpatel2000 in #1645
- Adding in FSDP Docs by @bcui19 in #1621
- Attempt to fix flaky doctest by @dakinggg in #1647
- Fix Missing Underscores in FSDP Docs by @bcui19 in #1648
- Fixed html path for make host command for docs by @karan6181 in #1642
- Fix hyperparameters logged to console even when progress_bar and log_to_console are False by @eracah in #1643
- Fix ImageNet Example normalization values by @Landanjs in #1641
- Python log level by @dakinggg in #1651
- Changed default logging to WARN for doctests by @eracah in #1644
- Add Event.AFTER_LOAD by @mvpatel2000 in #1652
- Lazy Logging + pretty print dict for hparams by @eracah in #1653
- Fix todo in memory monitor by @mvpatel2000 in #1654
- Tests for Idempotent Surgery by @mvpatel2000 in #1639
- Remove c4 dataset by @mvpatel2000 in #1635
- Update torchmetrics by @hanlint in #1656
- Search index filtered by project by @nqn in #1549
- FSDP Tests by @bcui19 in #1650
- Add composer version to issue template by @dakinggg in #1657
- Fix many failing notebook tests by @dakinggg in #1646
- Re-build the Docker images to resolve pip version error by @bandish-shah in #1655
Full Changelog: v0.10.1...v0.11.0