diff --git a/lm_engine/arguments.py b/lm_engine/arguments.py
index 2a21887c1..93f45bc71 100644
--- a/lm_engine/arguments.py
+++ b/lm_engine/arguments.py
@@ -191,12 +191,7 @@ class OptimizerArgs(BaseArgs):
     # backward hooked optimizer
     use_optimizer_with_backward_hook: bool = False
     # class args for optimizer
-    class_args: dict = {
-        "lr": 1e-5,
-        "weight_decay": 0.1,
-        "betas": [0.9, 0.95],
-        "eps": 1e-10,
-    }
+    class_args: dict = {"lr": 1e-5, "weight_decay": 0.1, "betas": [0.9, 0.95], "eps": 1e-10}
 
     def model_post_init(self, __context: Any) -> None:
         _check_not_None([(self.class_name, "optimizer class_name")])
diff --git a/lm_engine/data/__init__.py b/lm_engine/data/__init__.py
index e2848518c..cb22c57eb 100644
--- a/lm_engine/data/__init__.py
+++ b/lm_engine/data/__init__.py
@@ -16,20 +16,23 @@
 from .ibm import get_ibm_dataloaders
 from .instruction_tuning import AlpacaDataset, DollyDataset, SlimOrcaDataset
 from .megatron import get_megatron_gpt_dataloaders
+from .phonebook import PhonebookDataset
 from .sampler import BlendedDistributedSampler
 from .sst2 import SST2Dataset
 from .utils import collate_fn, custom_iterator, get_next_batch
 
 
-_DATASETS_LIST = {
-    "AlpacaDataset": AlpacaDataset,
-    "DebugDataset": DebugDataset,
-    "DollyDataset": DollyDataset,
-    "HuggingFaceDataset": HuggingFaceDataset,
-    "SlimOrcaDataset": SlimOrcaDataset,
-    "SST2Dataset": SST2Dataset,
+_FINETUNING_DATASETS_MAPPING = {
+    AlpacaDataset.__name__: AlpacaDataset,
+    DebugDataset.__name__: DebugDataset,
+    DollyDataset.__name__: DollyDataset,
+    HuggingFaceDataset.__name__: HuggingFaceDataset,
+    SlimOrcaDataset.__name__: SlimOrcaDataset,
+    SST2Dataset.__name__: SST2Dataset,
 }
 
+_PRETRAINING_DATASETS_MAPPING = {PhonebookDataset.__name__: PhonebookDataset}
+
 
 def get_datasets_list(
     dataset_args_list: list[DatasetArgs], split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
@@ -52,10 +55,10 @@ def get_datasets_list(
     datasets_list = []
     data_sampling_ratios = []
     for data_args in dataset_args_list:
-        if data_args.class_name not in _DATASETS_LIST:
+        if data_args.class_name not in _FINETUNING_DATASETS_MAPPING:
             raise ValueError(f"invalid class_name ({data_args.class_name}) for dataset")
 
-        dataset = _DATASETS_LIST[data_args.class_name](
+        dataset = _FINETUNING_DATASETS_MAPPING[data_args.class_name](
             class_args=data_args.class_args,
             split=split,
             mode=mode,
@@ -155,11 +158,44 @@ def get_finetuning_dataloader(
 
 
 def get_pretraining_dataloaders(
-    args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int
+    args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int, mode: Mode
 ) -> tuple[ResumableDataLoader, list[ResumableDataLoader], list[ResumableDataLoader]]:
-    if args.datasets[0].class_name == "MegatronDataset":
+    assert len(args.datasets) == 1
+    class_name = args.datasets[0].class_name
+
+    if class_name in _PRETRAINING_DATASETS_MAPPING:
+        assert args.load_args is None
+
+        train_dataloader = _PRETRAINING_DATASETS_MAPPING[class_name](
+            class_args=args.datasets[0].class_args,
+            split=DatasetSplit.train,
+            mode=mode,
+            tokenizer=tokenizer,
+            data_name="",
+            input_format="__input__",
+            output_format="__output__",
+            max_input_tokens=None,
+            max_output_tokens=None,
+        )
+
+        val_dataloaders = [
+            _PRETRAINING_DATASETS_MAPPING[class_name](
+                class_args=args.datasets[0].class_args,
+                split=DatasetSplit.val,
+                mode=mode,
+                tokenizer=tokenizer,
+                data_name="",
+                input_format="__input__",
+                output_format="__output__",
+                max_input_tokens=None,
+                max_output_tokens=None,
+            )
+        ]
+
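+        # NOTE: the validation dataloaders are reused for the last element (test dataloaders) of the returned tuple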
+        dataloaders = (train_dataloader, val_dataloaders, val_dataloaders)
+    elif class_name == "MegatronDataset":
         dataloaders = get_megatron_gpt_dataloaders(args, tokenizer, consumed_samples=consumed_samples)
-    elif args.datasets[0].class_name == "IBMDataset":
+    elif class_name == "IBMDataset":
         dataloaders = get_ibm_dataloaders(args, tokenizer)
 
     return dataloaders
diff --git a/lm_engine/data/phonebook.py b/lm_engine/data/phonebook.py
new file mode 100644
index 000000000..a9b81298d
--- /dev/null
+++ b/lm_engine/data/phonebook.py
@@ -0,0 +1,83 @@
+# **************************************************
+# Copyright (c) 2025, Mayank Mishra
+# **************************************************
+
+from __future__ import annotations
+
+import itertools
+import random
+import string
+
+from tqdm import trange
+
+from ..enums import DatasetSplit, Mode
+from ..tokenizers import TOKENIZER_TYPE
+from .base import BaseDataset
+
+
+class PhonebookDataset(BaseDataset):
+    def __init__(
+        self,
+        class_args: dict,
+        split: DatasetSplit,
+        mode: Mode,
+        tokenizer: TOKENIZER_TYPE,
+        data_name: str,
+        input_format: str,
+        output_format: str,
+        max_input_tokens: int,
+        max_output_tokens: int,
+    ) -> PhonebookDataset:
+        super().__init__(
+            class_args=class_args,
+            split=split,
+            mode=mode,
+            tokenizer=tokenizer,
+            data_name=data_name,
+            input_format=input_format,
+            output_format=output_format,
+            max_input_tokens=max_input_tokens,
+            max_output_tokens=max_output_tokens,
+        )
+
+        self.separator_token = ""
+
+        assert not self.do_format_input
+        assert not self.do_format_output
+        assert self.max_input_tokens is None
+        assert self.max_output_tokens is None
+        assert self.separator_token in tokenizer.get_vocab()
+
+        name_length = self.class_args["name_length"]
+        num_digits = self.class_args["num_digits"]
+        seed = self.class_args.get("seed", 42)
+
+        num_total_names = 26**name_length
+        num_phone_numbers = 10**num_digits
+
+        self.phonebook_size = self.class_args.get("phonebook_size", min(num_total_names, num_phone_numbers))
+
+        assert (
+            min(num_total_names, num_phone_numbers) >= self.phonebook_size
+        ), f"either {num_total_names} or {num_phone_numbers} is too small!"
+
+        names = list(itertools.product(list(string.ascii_lowercase), repeat=name_length))
+        phone_numbers = list(itertools.product(list(string.digits), repeat=num_digits))
+
+        local_random = random.Random(seed)
+        local_random.shuffle(names)
+        local_random.shuffle(phone_numbers)
+
+        names = names[: self.phonebook_size]
+        phone_numbers = phone_numbers[: self.phonebook_size]
+
+        self.examples = []
+        for i in trange(self.phonebook_size):
+            sample = "".join(names[i]) + self.separator_token + "".join(phone_numbers[i])
+            sample = tokenizer(sample, add_special_tokens=False)
+            sample += [tokenizer.eos_token_id]
+
+            self.examples.append(sample)
+
+    def __len__(self) -> int:
+        return self.phonebook_size
diff --git a/lm_engine/tokenizers/alpha_numeric.py b/lm_engine/tokenizers/alpha_numeric.py
index 569b7de35..d89372ac4 100644
--- a/lm_engine/tokenizers/alpha_numeric.py
+++ b/lm_engine/tokenizers/alpha_numeric.py
@@ -71,19 +71,27 @@ def __call__(
 
     def _get_token_id(self, x: str) -> None:
         assert isinstance(x, str)
-        assert len(x) == 1
-        xid = ord(x)
-
-        if self._0 <= xid <= self._9:
-            y = xid - self._0
-        elif self.a <= xid <= self.z:
-            y = xid - self.a + 10
-        elif self.A <= xid <= self.Z:
-            y = xid - self.A + 36
-        elif xid == self.eos_token:
+        if len(x) == 1:
+            xid = ord(x)
+
+            if self._0 <= xid <= self._9:
+                y = xid - self._0
+            elif self.a <= xid <= self.z:
+                y = xid - self.a + 10
+            elif self.A <= xid <= self.Z:
+                y = xid - self.A + 36
+            else:
+                raise ValueError(f"unexpected token ({x})")
+        elif x == self.eos_token:
             y = self.eos_token_id
+        elif x in self.special_tokens:
+            y = self.special_tokens[x]
         else:
             raise ValueError(f"unexpected token ({x})")
 
         return y
+
+    def add_special_tokens(self, special_tokens: dict) -> None:
+        for i, token in enumerate(special_tokens["additional_special_tokens"]):
+            self.special_tokens[token] = self.eos_token_id + i + 1
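
For reference, a minimal self-contained sketch (not part of the diff) of the sample construction that PhonebookDataset.__init__ performs: the name_length, num_digits, phonebook_size and seed values below are made up, and "<sep>" stands in for the dataset's separator token, which in the real class must already be in the tokenizer vocabulary; the dataset then tokenizes each line and appends the EOS token id.

import itertools
import random
import string

name_length, num_digits, phonebook_size, seed = 3, 4, 5, 42

# enumerate every possible lowercase name and every possible phone number
names = list(itertools.product(list(string.ascii_lowercase), repeat=name_length))
phone_numbers = list(itertools.product(list(string.digits), repeat=num_digits))

# deterministic shuffle, then keep the first phonebook_size pairs
local_random = random.Random(seed)
local_random.shuffle(names)
local_random.shuffle(phone_numbers)

for name, number in zip(names[:phonebook_size], phone_numbers[:phonebook_size]):
    # prints one "<name><sep><digits>" line per phonebook entry
    print("".join(name) + "<sep>" + "".join(number))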