From 236ac19318e4f5d5ca30637489c8bb01b745cab4 Mon Sep 17 00:00:00 2001
From: Jackmin801 <56836461+Jackmin801@users.noreply.github.com>
Date: Mon, 7 Oct 2024 00:17:02 -0700
Subject: [PATCH] Dataset splitting and mixing (#50)

* allow logger to take name

* first draft

* allow ratios

* shorter debug string for data files

* patch logger world info when not dist

* modify train script

* remove unused columns

* drop misaligning columns
---
 src/zeroband/data.py          | 142 ++++++++++++++++++++++++++++++----
 src/zeroband/train.py         |  16 +---
 src/zeroband/utils/logging.py |  13 +++-
 tests/test_data.py            |  78 +++++++++++++++++++
 4 files changed, 219 insertions(+), 30 deletions(-)

diff --git a/src/zeroband/data.py b/src/zeroband/data.py
index 1a093d1c..f45c9c2e 100644
--- a/src/zeroband/data.py
+++ b/src/zeroband/data.py
@@ -1,18 +1,37 @@
 from functools import partial
-from typing import Any, Generator
+from typing import Any, Generator, Optional, List, Dict, Union
+from pydantic_config import BaseConfig
+from zeroband.utils.logging import get_logger
 
 import torch
 from torch.utils.data import DataLoader
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, Dataset
 from torchdata.stateful_dataloader import StatefulDataLoader
-from datasets import load_dataset
+from datasets import load_dataset, interleave_datasets, load_dataset_builder, BuilderConfig
 from datasets.distributed import split_dataset_by_node
+import functools
 
 TEST_VOCAB_SIZE = 1024
 
 # TODO sami: make sure the init of the model is the same on all rank
 
+logger = get_logger(__name__)
+
+
+class DataConfig(BaseConfig):
+    dataset_name_or_paths: str = "allenai/c4:en"
+    val_dataset_name_or_paths: Optional[str] = None
+    seq_length: int = 1024
+    fake: bool = False
+    num_workers: int = 4
+    streaming: bool = True
+    max_train_samples: Optional[int] = None
+    max_eval_samples: Optional[int] = None
+    dataset_ratio: Optional[str] = None
+    data_rank: Optional[int] = None
+    data_world_size: Optional[int] = None
+
 
 class FakeTokenizedDataset(IterableDataset):
     """This is a dummy dataset that generates random sequences of length seq_len and vocab_size"""
@@ -61,28 +80,121 @@ def _collate_fn_causal_mask(
     return {"input_ids": torch.stack(batched["input_ids"], dim=0), "labels": torch.stack(batched["labels"], dim=0)}
 
 
-def get_dataloader(
-    tokenizer, world_size: int, rank: int, seq_length: int, batch_size: int, num_workers: int, fake_data: bool
-) -> DataLoader:
-    if fake_data:
-        train_dataset = FakeTokenizedDataset(seq_length, TEST_VOCAB_SIZE)
+def get_dataloader(tokenizer, world_size: int, rank: int, batch_size: int, data_config: DataConfig) -> DataLoader:
+    if data_config.fake:
+        train_dataset = FakeTokenizedDataset(data_config.seq_length, TEST_VOCAB_SIZE)
     else:
-        ds = load_dataset("allenai/c4", "en", streaming=True)
+        ds = load_all_datasets(data_config=data_config, split="train")
 
         def tokenize_function(data):
-            outputs = tokenizer(data["text"], truncation=True, max_length=seq_length)
+            outputs = tokenizer(data["text"], truncation=True, max_length=data_config.seq_length)
             return outputs
 
-        tokenized_datasets = ds.map(
-            tokenize_function, batched=True, remove_columns=["text", "timestamp", "url", "attention_mask"]
-        )["train"]
+        tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "attention_mask"])
         train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)
 
-    data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100)
+    data_collator = collate_causal_mask(
+        max_seq_length=data_config.seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100
+    )
 
     return StatefulDataLoader(
         train_dataset,
         collate_fn=data_collator,
         batch_size=batch_size,
-        num_workers=num_workers,
+        num_workers=data_config.num_workers,
     )
+
+
+@functools.lru_cache(maxsize=None)
+def _get_ds_config_dict(path: str, name: Optional[str] = None) -> Dict[str, BuilderConfig]:
+    ds_builder = load_dataset_builder(path=path, name=name)
+    return ds_builder.builder_configs
+
+
+def _get_datafiles(path: str, name: Optional[str] = None, split: str = "train") -> List[str]:
+    builder_config = _get_ds_config_dict(path=path, name=name)
+    if name is None:
+        if "default" not in builder_config:
+            logger.warning(f"Default config not found for {path}. Using first config.")
+            name = next(iter(builder_config.keys()))
+        else:
+            name = "default"
+    return builder_config[name].data_files[split]
+
+
+def _nice_print(kwargs: Dict[str, Union[str, List[str]]]) -> str:
+    def _shorten(a):
+        if isinstance(a, list):
+            return str(a[:5]) + "..." + str(a[-5:]) if len(a) > 10 else str(a)
+        return str(a)
+
+    return str({k: _shorten(v) for k, v in kwargs.items()})
+
+
+def _load_datasets(
+    dataset_names: str,
+    split: str,
+    data_rank: Optional[int] = None,
+    data_world_size: Optional[int] = None,
+    streaming: bool = True,
+    probabilities: Optional[List[float]] = None,
+) -> Dataset:
+    logger.debug(dataset_names)
+    ds_args = []
+    for _ds in dataset_names.split(","):
+        _ds_name, _, _ds_config = _ds.partition(":")
+        _ds_args = {"path": _ds_name}
+        if _ds_config:
+            _ds_args["name"] = _ds_config
+        if data_rank is not None and data_world_size is not None:
+            _data_files = _get_datafiles(_ds_name, _ds_config, split)
+            _ds_args["data_files"] = _data_files[data_rank::data_world_size]
+        ds_args.append(_ds_args)
+
+    logger.debug(f"Datasets ({split}):\n" + "\n".join(map(_nice_print, ds_args)))
+    logger.debug(f"Probabilities: {probabilities}")
+    logger.debug(f"Loading datasets{' in streaming mode' if streaming else ''}")
+    datasets = []
+    for ds_arg in ds_args:
+        logger.debug(f"Loading dataset: {ds_arg}")
+        _ds = load_dataset(**ds_arg, split=split, streaming=streaming)
+        _ds = _ds.remove_columns([i for i in _ds.column_names if i not in ["text"]])
+        datasets.append(_ds)
+        logger.debug(f"Loaded dataset: {ds_arg}")
+
+    ds = interleave_datasets(
+        datasets=datasets,
+        probabilities=probabilities,
+    )
+    logger.info(f"Loaded datasets ({split})")
+    return ds
+
+
+def _get_probabilities(data_config: DataConfig) -> Optional[List[float]]:
+    if data_config.dataset_ratio is None:
+        return None
+    if len(data_config.dataset_name_or_paths.split(",")) != len(data_config.dataset_ratio.split(":")):
+        raise ValueError("Number of datasets and dataset ratios must be the same")
+    nums = [float(i) for i in data_config.dataset_ratio.split(":")]
+    denom = sum(nums)
+    return [i / denom for i in nums]
+
+
+def load_all_datasets(data_config: DataConfig, split: str, max_samples: Optional[int] = None) -> IterableDataset:
+    """Load all datasets and interleave them"""
+    if max_samples is not None and not data_config.streaming:
+        split = f"{split}[:{max_samples}]"
+    ds = _load_datasets(
+        dataset_names=data_config.dataset_name_or_paths,
+        split=split,
+        data_rank=data_config.data_rank,
+        data_world_size=data_config.data_world_size,
+        streaming=data_config.streaming,
+        probabilities=_get_probabilities(data_config),
+    )
+    if max_samples is not None and data_config.streaming:
+        # streamed splits cannot be sliced, so cap the requested sample budget here
+        ds = ds.take(max_samples)
+    logger.info(f"Train dataset:\n{ds}")
+
+    return ds
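
Note: a minimal, self-contained sketch of the mixing behaviour that _get_probabilities and _load_datasets implement above. The two c4 configs are illustrative stand-ins for any pair of streaming datasets with a "text" column; a "3:2" ratio string normalises to sampling probabilities [0.6, 0.4], exactly as in _get_probabilities.

    # Sketch of the ratio -> probabilities -> interleave flow (assumes HF Hub access).
    from datasets import load_dataset, interleave_datasets

    ratio = "3:2"  # same format as DataConfig.dataset_ratio
    nums = [float(x) for x in ratio.split(":")]
    probabilities = [x / sum(nums) for x in nums]  # -> [0.6, 0.4]

    ds_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
    ds_news = load_dataset("allenai/c4", "realnewslike", split="train", streaming=True)
    mixed = interleave_datasets([ds_en, ds_news], probabilities=probabilities)

    for sample, _ in zip(mixed, range(5)):
        print(sample["text"][:60])
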
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index ae3d443b..7ac0ce93 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -24,7 +24,7 @@
 from zeroband.utils import GPUMemoryMonitor, PerfCounter, get_module_signature, get_sharding_strategy
 from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
-from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
+from zeroband.data import TEST_VOCAB_SIZE, get_dataloader, DataConfig
 from zeroband.models.llama import get_model
 from zeroband.utils.profiler import MemoryProfiler
 from zeroband.utils.world_info import get_world_info
@@ -32,12 +32,6 @@
 from zeroband.checkpoint import CkptManager, TrainingProgress
 
 
-class DataConfig(BaseConfig):
-    seq_length: int = 1024
-    fake: bool = False
-    num_workers: int = 4
-
-
 class OptimConfig(BaseConfig):
     lr: float = 4e-4
     weight_decay: float = 0.1
@@ -112,12 +106,10 @@ def train(config: Config):
 
     train_dataloader = get_dataloader(
         tokenizer=tokenizer,
-        world_size=world_info.world_size * world_info.global_world_size,
-        rank=world_info.rank + world_info.global_rank * world_info.global_world_size,
-        seq_length=config.data.seq_length,
+        world_size=world_info.world_size,
+        rank=world_info.rank,
         batch_size=config.train.micro_bs,
-        num_workers=config.data.num_workers,
-        fake_data=config.data.fake,
+        data_config=config.data,
     )
 
     model, model_config = get_model(
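
Note: a hedged sketch of how the new call site composes outside the trainer, on a single process. The tokenizer name is a placeholder and the pad-token line is an assumption for tokenizers that ship without one; the keyword arguments mirror the new get_dataloader signature.

    # Sketch: build a train dataloader from a DataConfig, no torchrun required.
    from transformers import AutoTokenizer
    from zeroband.data import DataConfig, get_dataloader

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    tokenizer.pad_token = tokenizer.eos_token  # assumption: gpt2 has no pad token

    data_config = DataConfig(
        dataset_name_or_paths="allenai/c4:en",
        seq_length=1024,
        streaming=True,
    )
    train_dataloader = get_dataloader(
        tokenizer=tokenizer,
        world_size=1,
        rank=0,
        batch_size=8,
        data_config=data_config,
    )
    batch = next(iter(train_dataloader))
    print(batch["input_ids"].shape)  # expect (8, 1024)
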
diff --git a/src/zeroband/utils/logging.py b/src/zeroband/utils/logging.py
index 2a64339f..9e5f6548 100644
--- a/src/zeroband/utils/logging.py
+++ b/src/zeroband/utils/logging.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import logging
 import os
 
@@ -24,13 +25,19 @@ def format(self, record):
         return formatter.format(record)
 
 
-def get_logger():
+def get_logger(name: Optional[str] = None) -> logging.Logger:
     global logger  # Add this line to modify the global logger variable
     if logger is not None:
         return logger
 
-    world_info = get_world_info()
-    logger = logging.getLogger(__name__)
+    try:
+        world_info = get_world_info()
+    except KeyError:
+        from zeroband.utils.world_info import WorldInfo
+
+        world_info = WorldInfo.__new__(WorldInfo)
+        world_info.local_rank = 0
+    logger = logging.getLogger(name or __name__)
 
     log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")
 
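
Note: the try/except above is what makes the logger usable outside a distributed launch; get_world_info() raises KeyError when the torchrun rank variables are absent, and the stub WorldInfo only needs local_rank. A sketch of that fallback path, assuming a plain python session with no rank variables set:

    # Sketch: get_logger falls back to local_rank = 0 without torchrun.
    import os

    os.environ["ZERO_BAND_LOG_LEVEL"] = "DEBUG"  # read on the first get_logger call

    from zeroband.utils.logging import get_logger

    logger = get_logger("standalone")  # the new optional name argument
    logger.debug("visible without torchrun")
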
diff --git a/tests/test_data.py b/tests/test_data.py
index 86440a12..eb71bf0f 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -1,5 +1,11 @@
 import torch
 from zeroband.data import collate_causal_mask
+from torch.utils.data import DataLoader
+from zeroband.data import load_all_datasets, DataConfig, logger as data_logger
+from collections import Counter
+from itertools import chain
+import pytest
+import logging
 
 
 def test_collate_fn():
@@ -17,3 +23,75 @@ def test_collate_fn():
 
     assert collated["input_ids"][1].tolist() == [0, 0, 3, 4]
     assert collated["labels"][1].tolist() == [0, 3, 4, 1]
+
+
+@pytest.mark.parametrize(
+    "ratio, lower, upper",
+    [
+        ("3:2", 1.2821, 1.7549),
+        ("0.5:1", 0.4247, 0.5886),
+    ],
+)
+def test_load_all_datasets_vanilla(ratio: str, lower: float, upper: float):
+    config = DataConfig(
+        dataset_name_or_paths="Jackmin108/abc-testing:A,Jackmin108/abc-testing:C",
+        dataset_ratio=ratio,
+        streaming=True,
+        fake=False,
+    )
+
+    ds = load_all_datasets(config, "train")
+    print(ds)
+
+    dl = DataLoader(ds, batch_size=256)
+    batches = [i["text"] for i, _ in zip(dl, range(10))]
+    assert len(batches) == 10
+
+    # Check that the ratio is correct
+    letter_count = Counter(i[0] for i in chain(*batches))
+    print(letter_count, letter_count["A"] / letter_count["C"])
+    assert letter_count["A"] / letter_count["C"] < upper
+    assert letter_count["A"] / letter_count["C"] > lower
+
+
+@pytest.mark.parametrize(
+    "ratio, lower, upper, data_rank, data_world_size",
+    [
+        ("3:2", 1.2821, 1.7549, 1, 4),
+        ("0.5:1", 0.4247, 0.5886, 0, 3),
+    ],
+)
+def test_load_all_datasets_data_rank(ratio: str, lower: float, upper: float, data_rank: int, data_world_size: int):
+    data_logger.setLevel(logging.DEBUG)
+    config = DataConfig(
+        dataset_name_or_paths="Jackmin108/abc-testing:A,Jackmin108/abc-testing:C",
+        dataset_ratio=ratio,
+        streaming=True,
+        fake=False,
+        data_world_size=data_world_size,
+        data_rank=data_rank,
+    )
+
+    ds = load_all_datasets(config, "train")
+    print(ds)
+
+    dl = DataLoader(ds, batch_size=256)
+    batches = [i["text"] for i, _ in zip(dl, range(10))]
+    assert len(batches) == 10
+
+    # Check that the ratio is correct
+    letter_count = Counter(i[0] for i in chain(*batches))
+    print(letter_count, letter_count["A"] / letter_count["C"])
+    assert letter_count["A"] / letter_count["C"] < upper
+    assert letter_count["A"] / letter_count["C"] > lower
+
+    c_num_set = {int(i[1:]) for i in chain(*batches) if i[0] == "C"}
+    a_num_set = {int(i[1:]) for i in chain(*batches) if i[0] == "A"}
+
+    # Check that the data is correctly sharded
+    first_a_shard = set(range(data_rank * (2**12), (data_rank + 1) * (2**12)))
+    first_10_c_shard = set()
+    for i in range(data_rank, data_world_size * 10, data_world_size):
+        first_10_c_shard = first_10_c_shard.union(set(range(i * (2**8), (i + 1) * (2**8))))
+    assert all(i in first_a_shard for i in a_num_set)
+    assert all(i in first_10_c_shard for i in c_num_set)
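
Note: the sharding assertions at the end of the second test reduce to the strided slice data_files[data_rank::data_world_size] in _load_datasets. A self-contained sketch of that selection, with hypothetical shard names:

    # Sketch: each data rank keeps every data_world_size-th file, offset by its rank.
    data_files = [f"shard-{i:04d}.parquet" for i in range(10)]  # hypothetical names
    data_rank, data_world_size = 1, 4

    my_files = data_files[data_rank::data_world_size]
    print(my_files)  # ['shard-0001.parquet', 'shard-0005.parquet', 'shard-0009.parquet']
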