diff --git a/.gitignore b/.gitignore index cbda94f..b7d5fc9 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ experiments/*.md # DeepSpeed deepspeed_logs/ +*.png diff --git a/configs/sft/sft_full.yaml b/configs/sft/sft_full.yaml index ad2ac66..4bb44e6 100644 --- a/configs/sft/sft_full.yaml +++ b/configs/sft/sft_full.yaml @@ -18,6 +18,8 @@ weight_decay: 0.01 lr_scheduler_type: "cosine" enable_gradient_checkpointing: true warmup_ratio: 0.1 +packing: true +packing_num_proc: 4 # Batch Size train_micro_batch_size_per_gpu: 4 diff --git a/configs/sft/sft_lora.yaml b/configs/sft/sft_lora.yaml index b056311..c77ae4b 100644 --- a/configs/sft/sft_lora.yaml +++ b/configs/sft/sft_lora.yaml @@ -16,6 +16,8 @@ learning_rate: 5e-5 lr_scheduler_type: "cosine" enable_gradient_checkpointing: true warmup_ratio: 0.1 +packing: true +packing_num_proc: 4 # Batch Size train_micro_batch_size_per_gpu: 4 diff --git a/src/bumblecore/cli/arg_parser.py b/src/bumblecore/cli/arg_parser.py index e8ddaf1..af8c614 100644 --- a/src/bumblecore/cli/arg_parser.py +++ b/src/bumblecore/cli/arg_parser.py @@ -47,6 +47,8 @@ def get_args(): action="store_true", default=cfg.get("enable_gradient_checkpointing", False), ) + parser.add_argument("--packing", type=bool, default=cfg.get("packing", True)) + parser.add_argument("--packing_num_proc", type=int, default=cfg.get("packing_num_proc", 4)) parser.add_argument("--warmup_ratio", type=float, default=cfg.get("warmup_ratio", 0.1)) # Batch Size diff --git a/src/bumblecore/config/train_config.py b/src/bumblecore/config/train_config.py index d7b05c8..f202d53 100644 --- a/src/bumblecore/config/train_config.py +++ b/src/bumblecore/config/train_config.py @@ -46,6 +46,10 @@ class TrainConfig: lora_dropout: float = field(default=0.1) lora_target_modules: Optional[Union[List[str], str]] = field(default=None) + # Packing settings + packing: bool = field(default=False) + packing_num_proc: int = field(default=1) + ld_alpha: float = field(default=1.0) pref_beta: float = 
field(default=0.1) diff --git a/src/bumblecore/data_processing/__init__.py b/src/bumblecore/data_processing/__init__.py index f8d24a9..f229e6a 100644 --- a/src/bumblecore/data_processing/__init__.py +++ b/src/bumblecore/data_processing/__init__.py @@ -1,15 +1,34 @@ -from .datasets import SFTDataset,PretrainDataset,DataCollator,DPODataset,DPOCollator +from .datasets import ( + SFTDataset, PretrainDataset, DataCollator, DPODataset, DPOCollator, + PackingDataCollator, +) from .preprocess import load_pretrain_data,load_sft_data,load_dpo_data from .data_formatter import DataFormatter +from .dataset_utils import ( + show_sample, + get_padding_value, + calculate_matched_group, + split_list, + is_master, + is_distributed, +) __all__ = [ - "SFTDataset", - "PretrainDataset", + "SFTDataset", + "PretrainDataset", "DataCollator", "load_pretrain_data", "load_sft_data", "load_dpo_data", "DataFormatter", "DPODataset", - "DPOCollator" + "DPOCollator", + "PackingDataCollator", + # Utility functions + "show_sample", + "get_padding_value", + "calculate_matched_group", + "split_list", + "is_master", + "is_distributed", ] diff --git a/src/bumblecore/data_processing/dataset_utils.py b/src/bumblecore/data_processing/dataset_utils.py new file mode 100644 index 0000000..dae0788 --- /dev/null +++ b/src/bumblecore/data_processing/dataset_utils.py @@ -0,0 +1,114 @@ +"""Utility functions for dataset processing.""" + +import io +from typing import List, Tuple + +import torch +import torch.distributed as dist +from rich.console import Console +from rich.table import Table +from rich.text import Text +from tqdm import tqdm + + +def show_sample( + input_ids, + labels, + tokenizer, + title="Input and Labels", + left_column="Input IDs", + right_column="Labels" +): + """Display a sample with input_ids and labels in a formatted table.""" + input_ids = input_ids.tolist() + labels = labels.tolist() + + valid_labels_list = [token_id for token_id in labels if token_id != -100] + decoded_input = 
tokenizer.decode(input_ids) + decoded_labels = tokenizer.decode(valid_labels_list) + + table = Table(show_header=True, show_lines=True, title=title) + table.add_column(left_column, overflow="fold") + table.add_column(right_column, overflow="fold") + + wrapped_input = Text(decoded_input, no_wrap=False, overflow="fold") + wrapped_labels = Text(decoded_labels, no_wrap=False, overflow="fold") + + table.add_row(str(input_ids), str(labels)) + table.add_row(wrapped_input, wrapped_labels) + + with io.StringIO() as buf: + console = Console(file=buf, force_terminal=False) + console.print(table) + output = buf.getvalue() + + tqdm.write(output.rstrip()) + + +def get_padding_value(tokenizer): + """Get the padding token id from tokenizer. + + If pad_token_id is not set, use eos_token_id as fallback. + """ + if tokenizer.pad_token_id is not None: + return tokenizer.pad_token_id + + eos = tokenizer.eos_token_id + return eos[0] if isinstance(eos, list) else eos + + +def calculate_matched_group(sequences: List[Tuple[int, int]], packing_length: int, is_finished: bool = True): + """Bin-packing via First Fit Decreasing (https://arxiv.org/pdf/2404.10830). + + Args: + sequences: List of (index, length) tuples. + packing_length: Maximum length for each pack. + is_finished: Whether this is the last batch. + + Returns: + Tuple of (packed_sequences, remaining_sequences). + packed_sequences is a list of lists, each containing (index, length) tuples. + """ + if len(sequences) == 0: + return [], [] + import binpacking + + # sequences 是 [(index, length), ...] 
列表 + # weight_pos=1 表示长度在元组第二个位置 + # 将一组物品分配到多个容量固定的箱子(bins)中,使得每个箱子的总容量不超过指定的最大值。 + sequences = binpacking.to_constant_volume(sequences, packing_length, weight_pos=1) + + # sequences 是列表的列表,每个子列表包含多个 (index, length) 元组 + # 如果不是最后一批,保留最后一个不完整组用于下一批 + if sequences and not is_finished: + sequences, ret_sequences = sequences[:-1], sequences[-1] + else: + ret_sequences = [] + return sequences, ret_sequences + + +def split_list(lst: list, n: int) -> List[list]: + """Split a list into n sublists as evenly as possible. + + Args: + lst: The list to split. + n: Number of parts to split into. + + Returns: + List of n sublists. + """ + # 划分列表为n个子列表,对应n个子进程处理 + k, m = divmod(len(lst), n) + return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)] + + +def is_master() -> bool: + """Check if current process is the master process in distributed training.""" + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() == 0 + return True + + +def is_distributed() -> bool: + """Check if running in distributed training mode.""" + return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 diff --git a/src/bumblecore/data_processing/datasets.py b/src/bumblecore/data_processing/datasets.py index 8b0a13a..9fef5a8 100644 --- a/src/bumblecore/data_processing/datasets.py +++ b/src/bumblecore/data_processing/datasets.py @@ -1,53 +1,20 @@ -import io - import torch import torch.distributed as dist -from torch.utils.data import Dataset,get_worker_info -from rich.console import Console -from rich.table import Table -from rich.text import Text +from torch.utils.data import Dataset, get_worker_info from tqdm import tqdm +import math +import multiprocessing as mp +from itertools import chain -def show_sample( - input_ids, - labels, - tokenizer, - title="Input and Labels" , - left_column = "Input IDs", - right_column = "Labels" -): - input_ids = input_ids.tolist() - labels = labels.tolist() - - valid_labels_list = [token_id for token_id in 
labels if token_id != -100] - decoded_input = tokenizer.decode(input_ids) - decoded_labels = tokenizer.decode(valid_labels_list) - - table = Table(show_header=True, show_lines=True, title=title) - table.add_column(left_column, overflow="fold") - table.add_column(right_column, overflow="fold") - - wrapped_input = Text(decoded_input, no_wrap=False, overflow="fold") - wrapped_labels = Text(decoded_labels, no_wrap=False, overflow="fold") - - table.add_row(str(input_ids), str(labels)) - table.add_row(wrapped_input, wrapped_labels) - - with io.StringIO() as buf: - console = Console(file=buf, force_terminal=False) - console.print(table) - output = buf.getvalue() - - tqdm.write(output.rstrip()) - - -def get_padding_value(tokenizer): - if tokenizer.pad_token_id is not None: - return tokenizer.pad_token_id - - eos = tokenizer.eos_token_id - return eos[0] if isinstance(eos, list) else eos +from .dataset_utils import ( + show_sample, + get_padding_value, + calculate_matched_group, + split_list, + is_master, + is_distributed, +) class PretrainDataset(Dataset): @@ -120,23 +87,136 @@ def __getitem__(self, idx): + class SFTDataset(Dataset): - + PACKING_BATCH_SIZE = 1000 + def __init__( self, train_dataset, tokenizer, max_length, + # ── new packing args ── + packing: bool = False, + packing_num_proc: int = 1, ): self.train_dataset = train_dataset self.tokenizer = tokenizer self.max_length = max_length - self.has_shown_sample = False - def __len__(self): - return len(self.train_dataset) - + if len(train_dataset) == 0: + raise ValueError("train_dataset cannot be empty") + + # ── packing bookkeeping ── + self.packing = packing + self.packing_length = max_length + self.packed_idx = None + self.packed_length = None + + if self.packing: + self.packing_num_proc = min( + packing_num_proc, + max(1, math.ceil(len(train_dataset) / self.PACKING_BATCH_SIZE)), + ) + self._out_queue = mp.Queue() + self._setup_packing() + + # ------------------------------------------------------------------ # + # 
packing index construction # + # ------------------------------------------------------------------ # + + def _compute_lengths(self) -> list[int]: + """Tokenize every sample once to get its length.""" + lengths = [] + for idx in tqdm(range(len(self.train_dataset)), desc="Computing sequence lengths"): + messages = self.train_dataset[idx]["messages"] + tools = self.train_dataset[idx].get("tools", None) + tokens = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=False, + truncation=True, + max_length=self.max_length, + tools=tools if tools else None, + ) + lengths.append(len(tokens)) + return lengths + + def _setup_packing(self): + """Build packed_idx / packed_length using multi-process bin-packing.""" + if is_master(): + # 计算每条数据的长度 + lengths = self._compute_lengths() + offset = 0 + chunked_lengths = split_list(lengths, self.packing_num_proc) + + # launch workers + for i in range(self.packing_num_proc): + worker = mp.Process( + target=self._create_packed_idx, + args=(i, offset, chunked_lengths[i]), + daemon=True, + ) + worker.start() + offset += len(chunked_lengths[i]) + + # collect results + self.packed_idx = [[] for _ in range(self.packing_num_proc)] + self.packed_length = [[] for _ in range(self.packing_num_proc)] + + desc = ( + "Packing: " + if self.packing_num_proc == 1 + else f"Packing (num_proc={self.packing_num_proc}): " + ) + with tqdm(total=len(lengths), dynamic_ncols=True, desc=desc) as pbar: + finished = 0 + while finished < self.packing_num_proc: + rank, sequences, data_len = self._out_queue.get() + if data_len == -1: # sentinel + finished += 1 + continue + pbar.update(data_len) + # (idx, length) + self.packed_idx[rank] += [[x[0] for x in seq] for seq in sequences] + # sum的结果应该接近packing_length + self.packed_length[rank] += [sum(x[1] for x in seq) for seq in sequences] + + self.packed_idx = list(chain.from_iterable(self.packed_idx)) + self.packed_length = list(chain.from_iterable(self.packed_length)) + else: + 
self.packed_idx, self.packed_length = None, None + + # broadcast to all ranks + if is_distributed(): + obj_list = [(self.packed_idx, self.packed_length)] + dist.broadcast_object_list(obj_list) + self.packed_idx, self.packed_length = obj_list[0] + + def _create_packed_idx(self, rank: int, offset: int, lengths: list[int]): + """Worker: stream bin-packing results back through self._out_queue.""" + # 这个i + offset 用来定位数据的源数据集中的位置 + data = [(i + offset, length) for i, length in enumerate(lengths)] + i = 0 + input_data: list = [] + while True: + new_data = data[i : i + self.PACKING_BATCH_SIZE] + input_data += new_data + if not input_data: + break + i += self.PACKING_BATCH_SIZE + is_finished = i >= len(data) + sequences, input_data = calculate_matched_group( + input_data, self.packing_length, is_finished=is_finished + ) + # (进程号,packing结果,剩余数据长度) + self._out_queue.put((rank, sequences, len(new_data))) + self._out_queue.put((rank, [], -1)) # sentinel + + # ------------------------------------------------------------------ # + # original SFT logic # + # ------------------------------------------------------------------ # def create_conversation_manually(self, messages, tools): @@ -158,7 +238,6 @@ def create_conversation_manually(self, messages, tools): for i, message in enumerate(messages): if message["role"] == "assistant": - context_with_reply = messages[: i + 1] full_tokens = self.tokenizer.apply_chat_template( context_with_reply, @@ -169,67 +248,119 @@ def create_conversation_manually(self, messages, tools): tools=tools if tools else None, ) reply_end_pos = len(full_tokens) - assistant_masks[current_pos:reply_end_pos] = [1] * (reply_end_pos - current_pos) - else: - - prompt_context = messages[: i + 1] - if message["role"] == "system": continue - - else: - prompt_tokens = self.tokenizer.apply_chat_template( - prompt_context, - tokenize=True, - add_generation_prompt=True, - truncation=True, - max_length=self.max_length, - tools=tools if tools else None, - ) - current_pos = 
len(prompt_tokens) + prompt_context = messages[: i + 1] + prompt_tokens = self.tokenizer.apply_chat_template( + prompt_context, + tokenize=True, + add_generation_prompt=True, + truncation=True, + max_length=self.max_length, + tools=tools if tools else None, + ) + current_pos = len(prompt_tokens) input_ids = torch.tensor(input_ids, dtype=torch.long) attention_mask = torch.tensor(attention_mask, dtype=torch.long) labels = input_ids.clone() - labels[torch.tensor(assistant_masks, dtype=torch.bool) == 0] = -100 return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - def _show_train_sample(self, input_ids, labels): - if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() else: - rank = 0 + rank = 0 worker_info = get_worker_info() is_main_worker = (worker_info is None) or (worker_info.id == 0) if rank == 0 and is_main_worker and not self.has_shown_sample: show_sample( input_ids=input_ids, labels=labels, - tokenizer = self.tokenizer, + tokenizer=self.tokenizer, title="SFT Input and Labels", left_column="Input IDs", - right_column="Labels" + right_column="Labels", ) self.has_shown_sample = True - - def __getitem__(self, idx): + + # ------------------------------------------------------------------ # + # __len__ / __getitem__ # + # ------------------------------------------------------------------ # + + def __len__(self): + if self.packing: + return len(self.packed_idx) + return len(self.train_dataset) + + def _process_single_sample(self, idx: int) -> dict: + """Tokenize one sample (shared by normal & packing paths).""" messages = self.train_dataset[idx]["messages"] - tools = self.train_dataset[idx]["tools"] - sample = self.create_conversation_manually(messages, tools) + tools = self.train_dataset[idx].get("tools", None) + return self.create_conversation_manually(messages, tools) - self._show_train_sample( - input_ids=sample["input_ids"], - labels=sample["labels"], - ) + def __getitem__(self, idx): + if self.packing: + return 
self._getitem_packing(idx) + sample = self._process_single_sample(idx) + # self._show_train_sample(input_ids=sample["input_ids"], labels=sample["labels"]) return sample + # ── packing __getitem__ ────────────────────────────────────────────── + + def _getitem_packing(self, idx): + """ + Concatenate the samples assigned to this pack, add per-sequence + position_ids (reset to 0 at each boundary). + + Returns + ------- + dict with keys: input_ids, attention_mask, labels, position_ids + All tensors are concatenated but NOT padded. + Padding is handled by PackingDataCollator. + + Notes + ----- + * ``position_ids`` resets to 0 at each sequence boundary, which + Flash-Attention-2 / flex-attention can use to build a + block-diagonal mask automatically. + """ + sequence_indices = self.packed_idx[idx] + + all_input_ids = [] + all_labels = [] + all_position_ids = [] + + for seq_idx in sequence_indices: + sample = self._process_single_sample(seq_idx) + input_ids = sample["input_ids"] # (seq_len,) + labels = sample["labels"] # (seq_len,) + seq_len = input_ids.size(0) + + all_input_ids.append(input_ids) + all_labels.append(labels) + all_position_ids.append(torch.arange(seq_len, dtype=torch.long)) + + # concat + input_ids = torch.cat(all_input_ids, dim=0) + labels = torch.cat(all_labels, dim=0) + position_ids = torch.cat(all_position_ids, dim=0) + attention_mask = torch.ones(input_ids.size(0), dtype=torch.long) + + self._show_train_sample(input_ids=input_ids, labels=labels) + + return dict( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=position_ids, + ) + class DPODataset(Dataset): def __init__(self, train_dataset, tokenizer, max_length): @@ -409,4 +540,70 @@ def _right_pad_to_len(sequences, max_length, padding_value): device=padded.device ) padded = torch.cat([padded, pad_tensor], dim=1) - return padded \ No newline at end of file + return padded + + +class PackingDataCollator: + """ + Data collator for packed sequences. 
+ + Pads a batch of packed sequences (with varying lengths) to the max length + in the batch. Similar to DataCollator but also handles position_ids. + """ + + def __init__(self, tokenizer): + self.input_ids_padding_value = get_padding_value(tokenizer=tokenizer) + + def __call__(self, batch): + if not batch: + return dict( + input_ids=torch.tensor([], dtype=torch.long).reshape(0, 0), + attention_mask=torch.tensor([], dtype=torch.long).reshape(0, 0), + labels=torch.tensor([], dtype=torch.long).reshape(0, 0), + position_ids=torch.tensor([], dtype=torch.long).reshape(0, 0), + ) + + input_ids = [item["input_ids"] for item in batch] + attention_mask = [item["attention_mask"] for item in batch] + labels = [item["labels"] for item in batch] + position_ids = [item["position_ids"] for item in batch] + + # Find max length in this batch + max_length = max(len(x) for x in input_ids) + + # Pad all sequences to max_length + input_ids = self._right_pad_to_len( + input_ids, max_length, self.input_ids_padding_value + ) + attention_mask = self._right_pad_to_len( + attention_mask, max_length, 0 + ) + labels = self._right_pad_to_len( + labels, max_length, -100 + ) + position_ids = self._right_pad_to_len( + position_ids, max_length, 0 + ) + + return dict( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=position_ids, + ) + + @staticmethod + def _right_pad_to_len(sequences, max_length, padding_value): + padded = torch.nn.utils.rnn.pad_sequence( + sequences, batch_first=True, padding_value=padding_value + ) + if padded.size(1) < max_length: + diff = max_length - padded.size(1) + pad_tensor = torch.full( + (padded.size(0), diff), + padding_value, + dtype=padded.dtype, + device=padded.device + ) + padded = torch.cat([padded, pad_tensor], dim=1) + return padded diff --git a/src/bumblecore/training/base_trainer.py b/src/bumblecore/training/base_trainer.py index 5f7a3b3..bcf69b4 100644 --- a/src/bumblecore/training/base_trainer.py +++ 
b/src/bumblecore/training/base_trainer.py @@ -171,8 +171,8 @@ def _instantiate_model(): if zero_stage == 3 and self.config.training_stage != "pretrain": self.dschf = HfDeepSpeedConfig(self.deepspeed_config) - with deepspeed.zero.Init(config_dict_or_path=self.deepspeed_config): - return _instantiate_model() + # with deepspeed.zero.Init(config_dict_or_path=self.deepspeed_config): + return _instantiate_model() else: return _instantiate_model() @@ -520,7 +520,9 @@ def _reset_accumulated_results(self): def _handle_logging_and_progress(self, global_step, computation_result): if global_step % self.config.logging_steps == 0: lr = self.model_engine.get_lr()[0] - grad_norm = self.model_engine.get_global_grad_norm().item() + grad_norm = self.model_engine.get_global_grad_norm() + if isinstance(grad_norm, torch.Tensor): + grad_norm = grad_norm.item() epochs_done = min(global_step / self.num_update_steps_per_epoch, math.ceil(self.config.num_epochs)) self.log_metrics(lr, grad_norm, epochs_done, global_step, computation_result) @@ -540,7 +542,17 @@ def compute_loss(self, model_engine, batch): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] - output = model_engine(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) + + # Support packing: pass position_ids if present + model_kwargs = dict( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False + ) + if "position_ids" in batch: + model_kwargs["position_ids"] = batch["position_ids"] + + output = model_engine(**model_kwargs) if self.config.average_tokens_across_devices: count = (labels != -100).sum().float() dist.all_reduce(count, op=dist.ReduceOp.SUM) diff --git a/src/bumblecore/training/sft_trainer.py b/src/bumblecore/training/sft_trainer.py index 8bae3be..6e91489 100644 --- a/src/bumblecore/training/sft_trainer.py +++ b/src/bumblecore/training/sft_trainer.py @@ -1,7 +1,7 @@ from transformers import AutoTokenizer from .base_trainer import BaseTrainer 
-from ..data_processing import DataFormatter, SFTDataset, load_sft_data, DataCollator +from ..data_processing import DataFormatter, SFTDataset, load_sft_data, DataCollator, PackingDataCollator from ..config import TrainConfig class SFTTrainer(BaseTrainer): @@ -9,7 +9,11 @@ def __init__(self, config: TrainConfig): self.config = config self.format_preprocess_fn = DataFormatter(self.config.training_stage) self.tokenizer, self.train_dataset = self._prepare_datasets() - self.data_collator = DataCollator(self.tokenizer) + # Use PackingDataCollator when packing is enabled, otherwise use standard DataCollator + if self.config.packing: + self.data_collator = PackingDataCollator(self.tokenizer) + else: + self.data_collator = DataCollator(self.tokenizer) super().__init__(config, self.train_dataset, self.tokenizer, self.data_collator) self._print_train_parameters() @@ -29,5 +33,7 @@ def _prepare_datasets(self): messages, tokenizer, max_length=self.config.cutoff_len, + packing=self.config.packing, + packing_num_proc=self.config.packing_num_proc, ) return tokenizer, train_dataset diff --git a/tests/bumblecore/data_processing/test_packing.py b/tests/bumblecore/data_processing/test_packing.py new file mode 100644 index 0000000..5c16a95 --- /dev/null +++ b/tests/bumblecore/data_processing/test_packing.py @@ -0,0 +1,538 @@ +import torch +import pytest +from transformers import AutoTokenizer +from bumblecore.data_processing import ( + SFTDataset, + DataCollator, + PackingDataCollator, +) + +tokenizer = AutoTokenizer.from_pretrained("./models/bumblebee") + + +# ============================== +# PackingDataCollator 测试 +# ============================== + +def test_packing_data_collator_basic(): + """测试 PackingDataCollator 的基本功能""" + collator = PackingDataCollator(tokenizer) + + # 创建包含 position_ids 的批次数据 + batch = [ + { + "input_ids": torch.tensor([1, 10, 20, 30, 2]), + "attention_mask": torch.tensor([1, 1, 1, 1, 1]), + "labels": torch.tensor([1, 10, 20, 30, 2]), + "position_ids": 
torch.tensor([0, 1, 2, 3, 4]), + }, + { + "input_ids": torch.tensor([1, 15, 25, 2]), + "attention_mask": torch.tensor([1, 1, 1, 1]), + "labels": torch.tensor([1, 15, 25, 2]), + "position_ids": torch.tensor([0, 1, 2, 3]), + }, + ] + + result = collator(batch) + + pad_token_id = tokenizer.pad_token_id + expected = { + "input_ids": torch.tensor([ + [1, 10, 20, 30, 2], + [1, 15, 25, 2, pad_token_id], + ], dtype=torch.long), + "attention_mask": torch.tensor([ + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 0], + ], dtype=torch.long), + "labels": torch.tensor([ + [1, 10, 20, 30, 2], + [1, 15, 25, 2, -100], + ], dtype=torch.long), + "position_ids": torch.tensor([ + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 0], + ], dtype=torch.long), + } + + assert torch.equal(result["input_ids"], expected["input_ids"]), f"input_ids 不匹配, 结果: {result['input_ids']}, 期望: {expected['input_ids']}" + assert torch.equal(result["attention_mask"], expected["attention_mask"]), "attention_mask 不匹配" + assert torch.equal(result["labels"], expected["labels"]), "labels 不匹配" + assert torch.equal(result["position_ids"], expected["position_ids"]), "position_ids 不匹配" + + +def test_packing_data_collator_with_varying_lengths(): + """测试 PackingDataCollator 处理不同长度序列""" + collator = PackingDataCollator(tokenizer) + + batch = [ + { + "input_ids": torch.tensor([1, 10, 2]), + "attention_mask": torch.tensor([1, 1, 1]), + "labels": torch.tensor([-100, 10, 2]), + "position_ids": torch.tensor([0, 1, 2]), + }, + { + "input_ids": torch.tensor([1, 15, 25, 35, 45, 2]), + "attention_mask": torch.tensor([1, 1, 1, 1, 1, 1]), + "labels": torch.tensor([-100, -100, 25, 35, 45, 2]), + "position_ids": torch.tensor([0, 1, 2, 3, 4, 5]), + }, + { + "input_ids": torch.tensor([1, 12, 2]), + "attention_mask": torch.tensor([1, 1, 1]), + "labels": torch.tensor([-100, 12, 2]), + "position_ids": torch.tensor([0, 1, 2]), + }, + ] + + result = collator(batch) + + pad_token_id = tokenizer.pad_token_id + expected = { + "input_ids": torch.tensor([ + [1, 10, 2, 
pad_token_id, pad_token_id, pad_token_id], + [1, 15, 25, 35, 45, 2], + [1, 12, 2, pad_token_id, pad_token_id, pad_token_id], + ], dtype=torch.long), + "attention_mask": torch.tensor([ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0], + ], dtype=torch.long), + "labels": torch.tensor([ + [-100, 10, 2, -100, -100, -100], + [-100, -100, 25, 35, 45, 2], + [-100, 12, 2, -100, -100, -100], + ], dtype=torch.long), + "position_ids": torch.tensor([ + [0, 1, 2, 0, 0, 0], + [0, 1, 2, 3, 4, 5], + [0, 1, 2, 0, 0, 0], + ], dtype=torch.long), + } + + assert torch.equal(result["input_ids"], expected["input_ids"]) + assert torch.equal(result["attention_mask"], expected["attention_mask"]) + assert torch.equal(result["labels"], expected["labels"]) + assert torch.equal(result["position_ids"], expected["position_ids"]) + + +def test_packing_data_collator_empty_batch(): + """测试 PackingDataCollator 处理空批次""" + collator = PackingDataCollator(tokenizer) + + batch = [] + + result = collator(batch) + + assert result["input_ids"].shape == (0, 0) + assert result["attention_mask"].shape == (0, 0) + assert result["labels"].shape == (0, 0) + assert result["position_ids"].shape == (0, 0) + + +# ============================== +# SFTDataset Packing 功能测试 +# ============================== + +def test_sft_dataset_packing_disabled(): + """测试 SFTDataset 禁用 packing 时的行为""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 禁用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=False) + + # 检查长度 + assert len(dataset) == len(train_dataset) + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" not in result # 禁用 packing 时不返回 position_ids + + # 检查数据类型 + assert 
isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + + +def test_sft_dataset_packing_enabled(): + """测试 SFTDataset 启用 packing 时的基本行为""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 检查长度(packing 会改变数据集长度) + assert len(dataset) >= 1 # 至少有一个pack + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result # 启用 packing 时返回 position_ids + + # 检查数据类型 + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + assert isinstance(result["position_ids"], torch.Tensor) + + # 检查序列长度一致性 + input_ids_len = len(result["input_ids"]) + assert len(result["attention_mask"]) == input_ids_len + assert len(result["labels"]) == input_ids_len + assert len(result["position_ids"]) == input_ids_len + + +def test_sft_dataset_packing_with_single_sample(): + """测试 SFTDataset packing 处理单个样本""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 检查长度 + assert len(dataset) == 1 # 单个样本应该只有一个pack + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + 
assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result + + # 检查 position_ids 是否正确重置 + position_ids = result["position_ids"] + assert torch.equal(position_ids, torch.arange(len(position_ids), dtype=torch.long)) + + +def test_sft_dataset_packing_position_ids_reset(): + """测试 SFTDataset packing 中 position_ids 的重置逻辑""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "Short message."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "Another short message."}, + {"role": "user", "content": "Hey"}, + {"role": "assistant", "content": "Hi there"}, + ], + "tools": None, + } + ] + + max_length = 100 # 设置较小的max_length以触发packing + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 获取样本 + result = dataset[0] + + # 检查 position_ids 是否正确重置 + position_ids = result["position_ids"] + + # 获取 packed_idx 中的实际顺序 + sequence_indices = dataset.packed_idx[0] + + # 根据实际顺序验证 position_ids + offset = 0 + for seq_idx in sequence_indices: + sample = dataset._process_single_sample(seq_idx) + seq_len = len(sample["input_ids"]) + + # 验证当前序列的 position_ids 正确重置为 [0, 1, ..., seq_len-1] + expected_position_ids = torch.arange(seq_len, dtype=torch.long) + actual_position_ids = position_ids[offset:offset + seq_len] + assert torch.equal(actual_position_ids, expected_position_ids), \ + f"Sequence {seq_idx}: expected {expected_position_ids.tolist()}, got {actual_position_ids.tolist()}" + + offset += seq_len + + # 验证总长度正确 + assert offset == len(position_ids), f"Total length mismatch: {offset} vs {len(position_ids)}" + + +def test_sft_dataset_packing_with_tools(): + """测试 SFTDataset packing 带 tools 的情况""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Calculate 2+2"}, + {"role": 
"assistant", "content": "4"}, + ], + "tools": [{"name": "calculator"}], + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 3*3?"}, + {"role": "assistant", "content": "9"}, + ], + "tools": [{"name": "calculator"}], + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 检查长度 + assert len(dataset) >= 1 + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result + + # 检查数据类型 + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + assert isinstance(result["position_ids"], torch.Tensor) + + +def test_sft_dataset_packing_labels_consistency(): + """测试 SFTDataset packing 中 labels 的一致性""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 获取样本 + result = dataset[0] + + input_ids = result["input_ids"] + labels = result["labels"] + attention_mask = result["attention_mask"] + + # 检查 labels 中非 -100 的位置与 input_ids 一致 + non_negative_100_mask = labels != -100 + assert torch.equal(input_ids[non_negative_100_mask], labels[non_negative_100_mask]) + + # 检查 attention_mask 的有效性 + assert torch.all(attention_mask == 1) # packing 时 attention_mask 应该全为1 + + +def test_sft_dataset_packing_edge_cases(): + """测试 SFTDataset packing 的边缘情况""" + + # 测试空数据集 + empty_dataset = [] 
+ with pytest.raises(ValueError, match="train_dataset cannot be empty"): + dataset = SFTDataset(empty_dataset, tokenizer, max_length=256, packing=True) + + # 测试单个长样本 + long_message_dataset = [ + { + "messages": [ + {"role": "system", "content": "A" * 1000}, # 很长的消息 + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 50 # 设置较小的max_length + + dataset = SFTDataset(long_message_dataset, tokenizer, max_length, packing=True) + result = dataset[0] + + # 检查序列长度不超过max_length + assert len(result["input_ids"]) <= max_length + + +# ============================== +# 集成测试:PackingDataCollator + SFTDataset +# ============================== + +def test_packing_integration(): + """测试 PackingDataCollator 和 SFTDataset 的集成""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 创建带 packing 的 dataset + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 创建 PackingDataCollator + collator = PackingDataCollator(tokenizer) + + # 创建批次 + batch = [dataset[i] for i in range(min(2, len(dataset)))] + + # 使用 collator 处理批次 + result = collator(batch) + + # 检查结果 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result + + # 检查批次维度 + batch_size = len(batch) + assert result["input_ids"].shape[0] == batch_size + assert result["attention_mask"].shape[0] == batch_size + assert result["labels"].shape[0] == batch_size + assert result["position_ids"].shape[0] == batch_size + + +def test_packing_vs_standard_collator(): + """比较 PackingDataCollator 和标准 DataCollator 的区别""" + 
train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 创建带 packing 的 dataset + dataset_packing = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + dataset_standard = SFTDataset(train_dataset, tokenizer, max_length, packing=False) + + # 创建两种 collator + packing_collator = PackingDataCollator(tokenizer) + standard_collator = DataCollator(tokenizer) + + # 获取样本 + sample_packing = dataset_packing[0] + sample_standard = dataset_standard[0] + + # 检查 packing 样本包含 position_ids + assert "position_ids" in sample_packing + assert "position_ids" not in sample_standard + + # 使用 collator 处理 + batch_packing = packing_collator([sample_packing]) + batch_standard = standard_collator([sample_standard]) + + # 检查 packing collator 返回 position_ids + assert "position_ids" in batch_packing + assert "position_ids" not in batch_standard + + +if __name__ == "__main__": + # 运行测试 + test_packing_data_collator_basic() + test_packing_data_collator_with_varying_lengths() + test_packing_data_collator_empty_batch() + test_sft_dataset_packing_disabled() + test_sft_dataset_packing_enabled() + test_sft_dataset_packing_with_single_sample() + test_sft_dataset_packing_position_ids_reset() + test_sft_dataset_packing_with_tools() + test_sft_dataset_packing_labels_consistency() + test_sft_dataset_packing_edge_cases() + test_packing_integration() + test_packing_vs_standard_collator() + + print("所有 packing 功能测试通过!") \ No newline at end of file diff --git a/tests/run_test.sh b/tests/run_test.sh index 28f77ef..445bddf 100644 --- a/tests/run_test.sh +++ b/tests/run_test.sh @@ -3,4 +3,5 @@ pytest tests/bumblecore/data_processing/test_data_formatter.py -v -s pytest tests/bumblecore/data_processing/test_datasets.py -v pytest tests/bumblecore/cli/test_arg_parser.py -v -pytest tests/bumblecore/training/test_launcher.py -v \ No newline 
at end of file
+pytest tests/bumblecore/training/test_launcher.py -v
+pytest tests/bumblecore/data_processing/test_packing.py -v