Dataset splitting and mixing (#50)
* allow logger to take name

* first draft

* allow ratios

* shorter debug string for data files

* patch logger world info when not dist

* modify train script

* remove unused columns

* drop misaligning columns
Jackmin801 authored Oct 7, 2024
1 parent afab101 commit 236ac19
Showing 4 changed files with 219 additions and 30 deletions.
142 changes: 127 additions & 15 deletions src/zeroband/data.py
@@ -1,18 +1,37 @@
from functools import partial
from typing import Any, Generator
from typing import Any, Generator, Optional, List, Dict, Union
from pydantic_config import BaseConfig
from zeroband.utils.logging import get_logger

import torch
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset
from torch.utils.data import IterableDataset, Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

from datasets import load_dataset
from datasets import load_dataset, interleave_datasets, load_dataset_builder, BuilderConfig
from datasets.distributed import split_dataset_by_node
import functools

TEST_VOCAB_SIZE = 1024

# TODO sami: make sure the init of the model is the same on all rank

logger = get_logger(__name__)


class DataConfig(BaseConfig):
dataset_name_or_paths: str = "allenai/c4:en"
val_dataset_name_or_paths: Optional[str] = None
seq_length: int = 1024
fake: bool = False
num_workers: int = 4
streaming: bool = True
max_train_samples: Optional[int] = None
max_eval_samples: Optional[int] = None
dataset_ratio: Optional[str] = None
data_rank: Optional[int] = None
data_world_size: Optional[int] = None


class FakeTokenizedDataset(IterableDataset):
"""This is a dummy dataset that generates random sequences of length seq_len and vocab_size"""
@@ -61,28 +80,121 @@ def _collate_fn_causal_mask(
return {"input_ids": torch.stack(batched["input_ids"], dim=0), "labels": torch.stack(batched["labels"], dim=0)}


def get_dataloader(
tokenizer, world_size: int, rank: int, seq_length: int, batch_size: int, num_workers: int, fake_data: bool
) -> DataLoader:
if fake_data:
train_dataset = FakeTokenizedDataset(seq_length, TEST_VOCAB_SIZE)
def get_dataloader(tokenizer, world_size: int, rank: int, batch_size: int, data_config: DataConfig) -> DataLoader:
if data_config.fake:
train_dataset = FakeTokenizedDataset(data_config.seq_length, TEST_VOCAB_SIZE)
else:
ds = load_dataset("allenai/c4", "en", streaming=True)
ds = load_all_datasets(data_config=data_config, split="train")

def tokenize_function(data):
outputs = tokenizer(data["text"], truncation=True, max_length=seq_length)
outputs = tokenizer(data["text"], truncation=True, max_length=data_config.seq_length)
return outputs

tokenized_datasets = ds.map(
tokenize_function, batched=True, remove_columns=["text", "timestamp", "url", "attention_mask"]
)["train"]
tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "attention_mask"])
train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)

data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100)
data_collator = collate_causal_mask(
max_seq_length=data_config.seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100
)

return StatefulDataLoader(
train_dataset,
collate_fn=data_collator,
batch_size=batch_size,
num_workers=num_workers,
num_workers=data_config.num_workers,
)


@functools.lru_cache(maxsize=None)
def _get_ds_config_dict(path: str, name: Optional[str] = None) -> Dict[str, BuilderConfig]:
ds_builder = load_dataset_builder(path=path, name=name)
return ds_builder.builder_configs


def _get_datafiles(path: str, name: Optional[str] = None, split: str = "train") -> List[str]:
builder_config = _get_ds_config_dict(path=path, name=name)
if name is None:
if "default" not in builder_config:
logger.warning(f"Default config not found for {path}. Using first config.")
name = next(iter(builder_config.keys()))
else:
name = "default"
return builder_config[name].data_files[split]


def _nice_print(kwargs: Dict[str, Union[str, List[str]]]) -> str:
def _foo(a):
if isinstance(a, list):
return str(a[:5]) + "..." + str(a[-5:]) if len(a) > 10 else str(a)
return str(a)

return str({k: _foo(v) for k, v in kwargs.items()})


def _load_datasets(
dataset_names: str,
split: str,
data_rank: Optional[int] = None,
data_world_size: Optional[int] = None,
streaming: bool = True,
probabilities: Optional[List[float]] = None,
) -> Dataset:
logger.debug(dataset_names)
ds_args = []
for _ds in dataset_names.split(","):
_ds_name, _, _ds_config = _ds.partition(":")
_ds_args = {"path": _ds_name}
if _ds_config:
_ds_args["name"] = _ds_config
if data_rank is not None and data_world_size is not None:
_data_files = _get_datafiles(_ds_name, _ds_config, split)
_ds_args["data_files"] = _data_files[data_rank::data_world_size]
ds_args.append(_ds_args)

logger.debug(f"Datasets ({split}):\n" + "\n".join(map(_nice_print, ds_args)))
logger.debug(f"Probabilities: {probabilities}")
logger.debug(f"Loading datasets{' in streaming mode' if streaming else ''}")
datasets = []
for ds_arg in ds_args:
logger.debug(f"Loading dataset: {ds_arg}")
_ds = load_dataset(**ds_arg, split=split, streaming=streaming)
_ds = _ds.remove_columns([i for i in _ds.column_names if i not in ["text"]])
datasets.append(_ds)
logger.debug(f"Loaded dataset: {ds_arg}")

ds = interleave_datasets(
datasets=datasets,
probabilities=probabilities,
)
logger.info(f"Loaded datasets ({split})")
return ds


def _get_probabilities(data_config: DataConfig) -> Optional[List[float]]:
if data_config.dataset_ratio is None:
return None
if len(data_config.dataset_name_or_paths.split(",")) != len(data_config.dataset_ratio.split(":")):
raise ValueError("Number of datasets and dataset ratios must be the same")
nums = [float(i) for i in data_config.dataset_ratio.split(":")]
denom = sum(nums)
return [i / denom for i in nums]


def load_all_datasets(data_config: DataConfig, split: str, max_samples: Optional[int] = None) -> IterableDataset:
"""Load all datasets and interleave them"""
if max_samples is not None and not data_config.streaming:
split = f"{split}[:{max_samples}]"
ds = _load_datasets(
dataset_names=data_config.dataset_name_or_paths,
split=split,
data_rank=data_config.data_rank,
data_world_size=data_config.data_world_size,
streaming=data_config.streaming,
probabilities=_get_probabilities(data_config),
)
if max_samples is not None and data_config.streaming:
if data_config.max_train_samples is not None:
ds = ds.take(data_config.max_train_samples)
logger.info(f"Train dataset:\n{ds}")

return ds
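A minimal usage sketch of the new loader, mirroring the values used in tests/test_data.py further down; this is an illustration of the added API, not part of the diff:

from torch.utils.data import DataLoader
from zeroband.data import DataConfig, load_all_datasets

config = DataConfig(
    # comma-separated "path:config" entries, parsed by _load_datasets
    dataset_name_or_paths="Jackmin108/abc-testing:A,Jackmin108/abc-testing:C",
    # colon-separated weights, normalized to interleave probabilities 0.6 / 0.4
    dataset_ratio="3:2",
    streaming=True,
    fake=False,
    # optional sharding: each dataset keeps data_files[data_rank::data_world_size]
    data_rank=1,
    data_world_size=4,
)

ds = load_all_datasets(config, split="train")  # interleaved stream with only a "text" column
dl = DataLoader(ds, batch_size=256)
batch = next(iter(dl))  # {"text": [...]}; tokenization happens later in get_dataloader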
16 changes: 4 additions & 12 deletions src/zeroband/train.py
@@ -24,20 +24,14 @@
from zeroband.utils import GPUMemoryMonitor, PerfCounter, get_module_signature, get_sharding_strategy
from zeroband.utils.activation_ckpt import apply_ac_ckpt
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader, DataConfig
from zeroband.models.llama import get_model
from zeroband.utils.profiler import MemoryProfiler
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
from zeroband.checkpoint import CkptManager, TrainingProgress


class DataConfig(BaseConfig):
seq_length: int = 1024
fake: bool = False
num_workers: int = 4


class OptimConfig(BaseConfig):
lr: float = 4e-4
weight_decay: float = 0.1
@@ -112,12 +106,10 @@ def train(config: Config):

train_dataloader = get_dataloader(
tokenizer=tokenizer,
world_size=world_info.world_size * world_info.global_world_size,
rank=world_info.rank + world_info.global_rank * world_info.global_world_size,
seq_length=config.data.seq_length,
world_size=world_info.world_size,
rank=world_info.rank,
batch_size=config.train.micro_bs,
num_workers=config.data.num_workers,
fake_data=config.data.fake,
data_config=config.data,
)

model, model_config = get_model(
13 changes: 10 additions & 3 deletions src/zeroband/utils/logging.py
@@ -1,3 +1,4 @@
from typing import Optional
import logging
import os

@@ -24,13 +25,19 @@ def format(self, record):
return formatter.format(record)


def get_logger():
def get_logger(name: Optional[str] = None) -> logging.Logger:
global logger # Add this line to modify the global logger variable
if logger is not None:
return logger

world_info = get_world_info()
logger = logging.getLogger(__name__)
try:
world_info = get_world_info()
except KeyError:
from zeroband.utils.world_info import WorldInfo

world_info = WorldInfo.__new__(WorldInfo)
world_info.local_rank = 0
logger = logging.getLogger(name or __name__)

log_level = os.getenv("ZERO_BAND_LOG_LEVEL", "INFO")

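A short usage sketch for the updated helper (illustration only): get_logger now accepts an optional name, and the first call works outside a distributed launch because a stub WorldInfo with local_rank 0 is substituted when get_world_info raises KeyError.

from zeroband.utils.logging import get_logger

logger = get_logger("zeroband.data")  # later calls return the same cached logger regardless of name
logger.info("dataset files resolved")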
78 changes: 78 additions & 0 deletions tests/test_data.py
@@ -1,5 +1,11 @@
import torch
from zeroband.data import collate_causal_mask
from torch.utils.data import DataLoader
from zeroband.data import load_all_datasets, DataConfig, logger as data_logger
from collections import Counter
from itertools import chain
import pytest
import logging


def test_collate_fn():
@@ -17,3 +23,75 @@ def test_collate_fn():

assert collated["input_ids"][1].tolist() == [0, 0, 3, 4]
assert collated["labels"][1].tolist() == [0, 3, 4, 1]


@pytest.mark.parametrize(
"ratio, lower, upper",
[
("3:2", 1.2821, 1.7549),
("0.5:1", 0.4247, 0.5886),
],
)
def test_load_all_datasets_vanilla(ratio: str, lower: float, upper: float):
config = DataConfig(
dataset_name_or_paths="Jackmin108/abc-testing:A,Jackmin108/abc-testing:C",
dataset_ratio=ratio,
streaming=True,
fake=False,
)

ds = load_all_datasets(config, "train")
print(ds)

dl = DataLoader(ds, batch_size=256)
batches = [i["text"] for i, _ in zip(dl, range(10))]
assert len(batches) == 10

# Check that the ratio is correct
letter_count = Counter(i[0] for i in chain(*batches))
print(letter_count, letter_count["A"] / letter_count["C"])
assert letter_count["A"] / letter_count["C"] < upper
assert letter_count["A"] / letter_count["C"] > lower


@pytest.mark.parametrize(
"ratio, lower, upper, data_rank, data_world_size",
[
("3:2", 1.2821, 1.7549, 1, 4),
("0.5:1", 0.4247, 0.5886, 0, 3),
],
)
def test_load_all_datasets_data_rank(ratio: str, lower: float, upper: float, data_rank: int, data_world_size: int):
data_logger.setLevel(logging.DEBUG)
config = DataConfig(
dataset_name_or_paths="Jackmin108/abc-testing:A,Jackmin108/abc-testing:C",
dataset_ratio=ratio,
streaming=True,
fake=False,
data_world_size=data_world_size,
data_rank=data_rank,
)

ds = load_all_datasets(config, "train")
print(ds)

dl = DataLoader(ds, batch_size=256)
batches = [i["text"] for i, _ in zip(dl, range(10))]
assert len(batches) == 10

# Check that the ratio is correct
letter_count = Counter(i[0] for i in chain(*batches))
print(letter_count, letter_count["A"] / letter_count["C"])
assert letter_count["A"] / letter_count["C"] < upper
assert letter_count["A"] / letter_count["C"] > lower

c_num_set = {int(i[1:]) for i in chain(*batches) if i[0] == "C"}
a_num_set = {int(i[1:]) for i in chain(*batches) if i[0] == "A"}

# Check that the data is correctly sharded
first_a_shard = set(range(data_rank * (2**12), (data_rank + 1) * (2**12)))
first_10_c_shard = set()
for i in range(data_rank, data_world_size * 10, data_world_size):
first_10_c_shard = first_10_c_shard.union(set(range(i * (2**8), (i + 1) * (2**8))))
assert all(i in first_a_shard for i in a_num_set)
assert all(i in first_10_c_shard for i in c_num_set)

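A worked example of the sharding rule this test verifies (the file names below are hypothetical): each dataset's file list is sliced as data_files[data_rank::data_world_size] before loading, so a given data rank only streams its own shards.

files = [f"shard-{i:05d}" for i in range(8)]  # hypothetical file list
data_rank, data_world_size = 1, 4
print(files[data_rank::data_world_size])  # ['shard-00001', 'shard-00005']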