7 changes: 1 addition & 6 deletions lm_engine/arguments.py
@@ -191,12 +191,7 @@ class OptimizerArgs(BaseArgs):
# backward hooked optimizer
use_optimizer_with_backward_hook: bool = False
# class args for optimizer
class_args: dict = {
"lr": 1e-5,
"weight_decay": 0.1,
"betas": [0.9, 0.95],
"eps": 1e-10,
}
class_args: dict = {"lr": 1e-5, "weight_decay": 0.1, "betas": [0.9, 0.95], "eps": 1e-10}

def model_post_init(self, __context: Any) -> None:
_check_not_None([(self.class_name, "optimizer class_name")])
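For reference, class_args is the keyword-argument payload for whatever optimizer class class_name selects. A minimal sketch of that pattern, assuming a plain torch.optim optimizer; this is illustrative, not lm_engine's actual construction code:

import torch

# Same values as the config default above; betas stays a list, which AdamW accepts.
class_args = {"lr": 1e-5, "weight_decay": 0.1, "betas": [0.9, 0.95], "eps": 1e-10}

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), **class_args)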
60 changes: 48 additions & 12 deletions lm_engine/data/__init__.py
@@ -16,20 +16,23 @@
from .ibm import get_ibm_dataloaders
from .instruction_tuning import AlpacaDataset, DollyDataset, SlimOrcaDataset
from .megatron import get_megatron_gpt_dataloaders
from .phonebook import PhonebookDataset
from .sampler import BlendedDistributedSampler
from .sst2 import SST2Dataset
from .utils import collate_fn, custom_iterator, get_next_batch


_DATASETS_LIST = {
"AlpacaDataset": AlpacaDataset,
"DebugDataset": DebugDataset,
"DollyDataset": DollyDataset,
"HuggingFaceDataset": HuggingFaceDataset,
"SlimOrcaDataset": SlimOrcaDataset,
"SST2Dataset": SST2Dataset,
_FINETUNING_DATASETS_MAPPING = {
AlpacaDataset.__name__: AlpacaDataset,
DebugDataset.__name__: DebugDataset,
DollyDataset.__name__: DollyDataset,
HuggingFaceDataset.__name__: HuggingFaceDataset,
SlimOrcaDataset.__name__: SlimOrcaDataset,
SST2Dataset.__name__: SST2Dataset,
}

_PRETRAINING_DATASETS_MAPPING = {PhonebookDataset.__name__: PhonebookDataset}


def get_datasets_list(
dataset_args_list: list[DatasetArgs], split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
@@ -52,10 +55,10 @@ def get_datasets_list(
datasets_list = []
data_sampling_ratios = []
for data_args in dataset_args_list:
if data_args.class_name not in _DATASETS_LIST:
if data_args.class_name not in _FINETUNING_DATASETS_MAPPING:
raise ValueError(f"invalid class_name ({data_args.class_name}) for dataset")

dataset = _DATASETS_LIST[data_args.class_name](
dataset = _FINETUNING_DATASETS_MAPPING[data_args.class_name](
class_args=data_args.class_args,
split=split,
mode=mode,
@@ -155,11 +158,44 @@ def get_finetuning_dataloader(


def get_pretraining_dataloaders(
args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int
args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int, mode: Mode
) -> tuple[ResumableDataLoader, list[ResumableDataLoader], list[ResumableDataLoader]]:
if args.datasets[0].class_name == "MegatronDataset":
assert len(args.datasets) == 1
class_name = args.datasets[0].class_name

if class_name in _PRETRAINING_DATASETS_MAPPING:
assert args.load_args is None

train_dataloader = _PRETRAINING_DATASETS_MAPPING[class_name](
class_args=args.datasets[0].class_args,
split=DatasetSplit.train,
mode=mode,
tokenizer=tokenizer,
data_name="",
input_format="__input__",
output_format="__output__",
max_input_tokens=None,
max_output_tokens=None,
)

val_dataloaders = [
_PRETRAINING_DATASETS_MAPPING[class_name](
class_args=args.datasets[0].class_args,
split=DatasetSplit.val,
mode=mode,
tokenizer=tokenizer,
data_name="",
input_format="__input__",
output_format="__output__",
max_input_tokens=None,
max_output_tokens=None,
)
]

dataloaders = (train_dataloader, val_dataloaders, val_dataloaders)
elif class_name == "MegatronDataset":
dataloaders = get_megatron_gpt_dataloaders(args, tokenizer, consumed_samples=consumed_samples)
elif args.datasets[0].class_name == "IBMDataset":
elif class_name == "IBMDataset":
dataloaders = get_ibm_dataloaders(args, tokenizer)

return dataloaders
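Putting the new dispatch together: a dataset whose class_name appears in _PRETRAINING_DATASETS_MAPPING (currently only PhonebookDataset) is built directly for the train and val splits, while MegatronDataset and IBMDataset keep their dedicated dataloader builders. A minimal sketch of a config entry that would take the new branch; only class_name and class_args are taken from the diff, the rest is assumed:

# Hypothetical config fragment for args.datasets.
datasets = [
    {
        "class_name": "PhonebookDataset",
        "class_args": {"name_length": 3, "num_digits": 4, "seed": 42},
    }
]
# get_pretraining_dataloaders(args, tokenizer, consumed_samples, mode) looks up
# "PhonebookDataset" in _PRETRAINING_DATASETS_MAPPING and builds the train/val
# datasets itself, skipping the Megatron and IBM branches. Note that this branch
# asserts args.load_args is None.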
83 changes: 83 additions & 0 deletions lm_engine/data/phonebook.py
@@ -0,0 +1,83 @@
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import itertools
import random
import string

from tqdm import trange

from ..enums import DatasetSplit, Mode
from ..tokenizers import TOKENIZER_TYPE
from .base import BaseDataset


class PhonebookDataset(BaseDataset):
def __init__(
self,
class_args: dict,
split: DatasetSplit,
mode: Mode,
tokenizer: TOKENIZER_TYPE,
data_name: str,
input_format: str,
output_format: str,
max_input_tokens: int,
max_output_tokens: int,
) -> PhonebookDataset:
super().__init__(
class_args=class_args,
split=split,
mode=mode,
tokenizer=tokenizer,
data_name=data_name,
input_format=input_format,
output_format=output_format,
max_input_tokens=max_input_tokens,
max_output_tokens=max_output_tokens,
)

self.separator_token = "<sep>"

assert not self.do_format_input
assert not self.do_format_output
assert self.max_input_tokens is None
assert self.max_output_tokens is None
assert self.separator_token in tokenizer.get_vocab()

name_length = self.class_args["name_length"]
num_digits = self.class_args["num_digits"]
seed = self.class_args.get("seed", 42)

num_total_names = 26**name_length
num_phone_numbers = 10**num_digits

self.phonebook_size = self.class_args.get("phonebook_size", min(num_total_names, num_phone_numbers))

assert (
min(num_total_names, num_phone_numbers) >= self.phonebook_size
), f"either {num_total_names} or {num_phone_numbers} is too small!"

names = list(itertools.product(list(string.ascii_lowercase), repeat=name_length))
phone_numbers = list(itertools.product(list(string.digits), repeat=num_digits))

local_random = random.Random(seed)
local_random.shuffle(names)
local_random.shuffle(phone_numbers)

names = names[: self.phonebook_size]
phone_numbers = phone_numbers[: self.phonebook_size]

self.examples = []
for i in trange(self.phonebook_size):
sample = "".join(names[i]) + self.separator_token + "".join(phone_numbers[i])
sample = tokenizer(sample, add_special_tokens=False)
sample += [tokenizer.eos_token_id]

self.examples.append(sample)

def __len__(self) -> int:
return self.phonebook_size
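Concretely, each phonebook entry is a name of name_length lowercase letters, the <sep> token, and a phone number of num_digits digits, tokenized and terminated with eos_token_id. A small sketch of the text one example encodes, with illustrative values:

# Illustrative values only; the real names/numbers come from shuffled
# itertools.product over string.ascii_lowercase and string.digits.
name_length, num_digits = 3, 4
name = "qzd"                            # 3 lowercase letters
number = "4810"                         # 4 digits
sample_text = name + "<sep>" + number   # "qzd<sep>4810"
# tokenizer(sample_text, add_special_tokens=False) + [tokenizer.eos_token_id]
# is what gets appended to self.examples.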
28 changes: 18 additions & 10 deletions lm_engine/tokenizers/alpha_numeric.py
@@ -71,19 +71,27 @@ def __call__(

def _get_token_id(self, x: str) -> None:
assert isinstance(x, str)
assert len(x) == 1

xid = ord(x)

if self._0 <= xid <= self._9:
y = xid - self._0
elif self.a <= xid <= self.z:
y = xid - self.a + 10
elif self.A <= xid <= self.Z:
y = xid - self.A + 36
elif xid == self.eos_token:
if len(x) == 1:
xid = ord(x)

if self._0 <= xid <= self._9:
y = xid - self._0
elif self.a <= xid <= self.z:
y = xid - self.a + 10
elif self.A <= xid <= self.Z:
y = xid - self.A + 36
else:
raise ValueError(f"unexpected token ({x})")
elif x == self.eos_token:
y = self.eos_token_id
elif x in self.special_tokens:
y = self.special_tokens[x]
else:
raise ValueError(f"unexpected token ({x})")

return y

def add_special_tokens(self, special_tokens: dict) -> None:
for i, token in enumerate(special_tokens["additional_special_tokens"]):
self.special_tokens[token] = self.eos_token_id + i + 1
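The rewrite drops the hard len(x) == 1 assertion and compares the token string itself (rather than ord(x)) for eos and registered special tokens, so multi-character tokens such as <sep> can now resolve to an id. A toy re-implementation of the resulting id layout; eos_token_id = 62 is only a guess to make the offsets concrete, and this is not the library code:

def toy_token_id(x: str, eos_token: str = "<eos>", eos_token_id: int = 62,
                 special_tokens: dict | None = None) -> int:
    # Mirrors the branches of _get_token_id above.
    special_tokens = {} if special_tokens is None else special_tokens
    if len(x) == 1:
        xid = ord(x)
        if ord("0") <= xid <= ord("9"):
            return xid - ord("0")        # '0'..'9' -> 0..9
        if ord("a") <= xid <= ord("z"):
            return xid - ord("a") + 10   # 'a'..'z' -> 10..35
        if ord("A") <= xid <= ord("Z"):
            return xid - ord("A") + 36   # 'A'..'Z' -> 36..61
        raise ValueError(f"unexpected token ({x})")
    if x == eos_token:
        return eos_token_id
    if x in special_tokens:
        return special_tokens[x]         # e.g. "<sep>" -> eos_token_id + 1
    raise ValueError(f"unexpected token ({x})")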