From 6f54f6089b927d9c0c994d06e179b1dbf0af87b9 Mon Sep 17 00:00:00 2001 From: JasonCZH4 Date: Sat, 7 Feb 2026 14:47:28 +0000 Subject: [PATCH 1/6] =?UTF-8?q?=E6=94=AF=E6=8C=81packing=E7=AE=97=E6=B3=95?= =?UTF-8?q?=EF=BC=8C=E6=8F=90=E9=AB=98=E8=AE=AD=E7=BB=83=E9=80=9F=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bumblecore/config/train_config.py | 2 + src/bumblecore/data_processing/datasets.py | 402 +++++++++++++++++++-- src/bumblecore/training/sft_trainer.py | 2 + 3 files changed, 370 insertions(+), 36 deletions(-) diff --git a/src/bumblecore/config/train_config.py b/src/bumblecore/config/train_config.py index d7b05c8..48dcb3d 100644 --- a/src/bumblecore/config/train_config.py +++ b/src/bumblecore/config/train_config.py @@ -18,6 +18,8 @@ class TrainConfig: weight_decay: float = field(default=0.01) warmup_ratio: float = field(default=0.1) num_epochs: float = field(default=3.0) + packing: bool = field(default=False) + packing_num_proc: int = field(default=1) lr_scheduler_type: str = field(default="cosine") train_micro_batch_size_per_gpu: int = field(default=4) diff --git a/src/bumblecore/data_processing/datasets.py b/src/bumblecore/data_processing/datasets.py index 8b0a13a..167ccfe 100644 --- a/src/bumblecore/data_processing/datasets.py +++ b/src/bumblecore/data_processing/datasets.py @@ -8,6 +8,10 @@ from rich.text import Text from tqdm import tqdm +import math +import multiprocessing as mp +from itertools import chain + def show_sample( input_ids, @@ -119,24 +123,164 @@ def __getitem__(self, idx): return sample +def calculate_matched_group(sequences, packing_length: int, is_finished: bool = True): + """Bin-packing via First Fit Decreasing (https://arxiv.org/pdf/2404.10830).""" + if len(sequences) == 0: + return [], [] + import binpacking + # sequences 是 [(index, length), ...] 列表 + # weight_pos=1 表示长度在元组第二个位置 + # 将一组物品分配到多个容量固定的箱子(bins)中,使得每个箱子的总容量不超过指定的最大值。 + sequences = binpacking.to_constant_volume(sequences, packing_length, weight_pos=1) + # sequences 是列表的列表,每个子列表包含多个 (index, length) 元组 + # 如果不是最后一批,保留最后一个不完整组用于下一批 + if sequences and not is_finished: + sequences, ret_sequences = sequences[:-1], sequences[-1] + else: + ret_sequences = [] + return sequences, ret_sequences + +def split_list(lst, n): + # 划分列表为n个子列表,对应n个子进程处理 + k, m = divmod(len(lst), n) + return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)] + + +def _is_master(): + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() == 0 + return True + + +def _is_dist(): + return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 class SFTDataset(Dataset): - + PACKING_BATCH_SIZE = 1000 + def __init__( self, train_dataset, tokenizer, max_length, + # ── new packing args ── + packing: bool = False, + packing_num_proc: int = 1, ): self.train_dataset = train_dataset self.tokenizer = tokenizer self.max_length = max_length - self.has_shown_sample = False - def __len__(self): - return len(self.train_dataset) - + # ── packing bookkeeping ── + self.packing = packing + self.packing_length = max_length + self.packed_idx = None + self.packed_length = None + + if self.packing: + self.packing_num_proc = min( + packing_num_proc, + max(1, math.ceil(len(train_dataset) / self.PACKING_BATCH_SIZE)), + ) + self._out_queue = mp.Queue() + self._setup_packing() + + # ------------------------------------------------------------------ # + # packing index construction # + # ------------------------------------------------------------------ # + + def _compute_lengths(self) -> list[int]: + """Tokenize every sample once to get its length.""" + lengths = [] + for idx in tqdm(range(len(self.train_dataset)), desc="Computing sequence lengths"): + messages = self.train_dataset[idx]["messages"] + tools = self.train_dataset[idx].get("tools", None) + tokens = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=False, + truncation=True, + max_length=self.max_length, + tools=tools if tools else None, + ) + lengths.append(len(tokens)) + return lengths + + def _setup_packing(self): + """Build packed_idx / packed_length using multi-process bin-packing.""" + if _is_master(): + # 计算每条数据的长度 + lengths = self._compute_lengths() + offset = 0 + chunked_lengths = split_list(lengths, self.packing_num_proc) + + # launch workers + for i in range(self.packing_num_proc): + worker = mp.Process( + target=self._create_packed_idx, + args=(i, offset, chunked_lengths[i]), + daemon=True, + ) + worker.start() + offset += len(chunked_lengths[i]) + + # collect results + self.packed_idx = [[] for _ in range(self.packing_num_proc)] + self.packed_length = [[] for _ in range(self.packing_num_proc)] + + desc = ( + "Packing: " + if self.packing_num_proc == 1 + else f"Packing (num_proc={self.packing_num_proc}): " + ) + with tqdm(total=len(lengths), dynamic_ncols=True, desc=desc) as pbar: + finished = 0 + while finished < self.packing_num_proc: + rank, sequences, data_len = self._out_queue.get() + if data_len == -1: # sentinel + finished += 1 + continue + pbar.update(data_len) + # (idx, length) + self.packed_idx[rank] += [[x[0] for x in seq] for seq in sequences] + # sum的结果应该接近packing_length + self.packed_length[rank] += [sum(x[1] for x in seq) for seq in sequences] + + self.packed_idx = list(chain.from_iterable(self.packed_idx)) + self.packed_length = list(chain.from_iterable(self.packed_length)) + else: + self.packed_idx, self.packed_length = None, None + + # broadcast to all ranks + if _is_dist(): + obj_list = [(self.packed_idx, self.packed_length)] + dist.broadcast_object_list(obj_list) + self.packed_idx, self.packed_length = obj_list[0] + + def _create_packed_idx(self, rank: int, offset: int, lengths: list[int]): + """Worker: stream bin-packing results back through self._out_queue.""" + # 这个i + offset 用来定位数据的源数据集中的位置 + data = [(i + offset, length) for i, length in enumerate(lengths)] + i = 0 + input_data: list = [] + while True: + new_data = data[i : i + self.PACKING_BATCH_SIZE] + input_data += new_data + if not input_data: + break + i += self.PACKING_BATCH_SIZE + is_finished = i >= len(data) + sequences, input_data = calculate_matched_group( + input_data, self.packing_length, is_finished=is_finished + ) + # (进程号,packing结果,剩余数据长度) + self._out_queue.put((rank, sequences, len(new_data))) + self._out_queue.put((rank, [], -1)) # sentinel + + # ------------------------------------------------------------------ # + # original SFT logic # + # ------------------------------------------------------------------ # def create_conversation_manually(self, messages, tools): @@ -158,7 +302,6 @@ def create_conversation_manually(self, messages, tools): for i, message in enumerate(messages): if message["role"] == "assistant": - context_with_reply = messages[: i + 1] full_tokens = self.tokenizer.apply_chat_template( context_with_reply, @@ -169,67 +312,254 @@ def create_conversation_manually(self, messages, tools): tools=tools if tools else None, ) reply_end_pos = len(full_tokens) - assistant_masks[current_pos:reply_end_pos] = [1] * (reply_end_pos - current_pos) - else: - - prompt_context = messages[: i + 1] - if message["role"] == "system": continue - - else: - prompt_tokens = self.tokenizer.apply_chat_template( - prompt_context, - tokenize=True, - add_generation_prompt=True, - truncation=True, - max_length=self.max_length, - tools=tools if tools else None, - ) - current_pos = len(prompt_tokens) + prompt_context = messages[: i + 1] + prompt_tokens = self.tokenizer.apply_chat_template( + prompt_context, + tokenize=True, + add_generation_prompt=True, + truncation=True, + max_length=self.max_length, + tools=tools if tools else None, + ) + current_pos = len(prompt_tokens) input_ids = torch.tensor(input_ids, dtype=torch.long) attention_mask = torch.tensor(attention_mask, dtype=torch.long) labels = input_ids.clone() - labels[torch.tensor(assistant_masks, dtype=torch.bool) == 0] = -100 return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - def _show_train_sample(self, input_ids, labels): - if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() else: - rank = 0 + rank = 0 worker_info = get_worker_info() is_main_worker = (worker_info is None) or (worker_info.id == 0) if rank == 0 and is_main_worker and not self.has_shown_sample: show_sample( input_ids=input_ids, labels=labels, - tokenizer = self.tokenizer, + tokenizer=self.tokenizer, title="SFT Input and Labels", left_column="Input IDs", - right_column="Labels" + right_column="Labels", ) self.has_shown_sample = True - - def __getitem__(self, idx): + + # ------------------------------------------------------------------ # + # __len__ / __getitem__ # + # ------------------------------------------------------------------ # + + def __len__(self): + if self.packing: + return len(self.packed_idx) + return len(self.train_dataset) + + def _process_single_sample(self, idx: int) -> dict: + """Tokenize one sample (shared by normal & packing paths).""" messages = self.train_dataset[idx]["messages"] - tools = self.train_dataset[idx]["tools"] - sample = self.create_conversation_manually(messages, tools) + tools = self.train_dataset[idx].get("tools", None) + return self.create_conversation_manually(messages, tools) - self._show_train_sample( - input_ids=sample["input_ids"], - labels=sample["labels"], - ) + def __getitem__(self, idx): + if self.packing: + return self._getitem_packing(idx) + sample = self._process_single_sample(idx) + self._show_train_sample(input_ids=sample["input_ids"], labels=sample["labels"]) return sample + # ── packing __getitem__ ────────────────────────────────────────────── + + def _getitem_packing(self, idx): + """ + Concatenate the samples assigned to this pack, add per-sequence + position_ids (reset to 0 at each boundary), and pad to + ``packing_length`` so every item in the batch has the same shape. + + Returns + ------- + dict with keys: input_ids, attention_mask, labels, position_ids + All of shape ``(packing_length,)``. + + Notes + ----- + * ``position_ids`` resets to 0 at each sequence boundary, which + Flash-Attention-2 / flex-attention can use to build a + block-diagonal mask automatically. + * Padding tokens get ``label = -100``, ``attention_mask = 0``, + ``position_ids = 0``. + """ + sequence_indices = self.packed_idx[idx] + + all_input_ids = [] + all_labels = [] + all_position_ids = [] + + for seq_idx in sequence_indices: + sample = self._process_single_sample(seq_idx) + input_ids = sample["input_ids"] # (seq_len,) + labels = sample["labels"] # (seq_len,) + seq_len = input_ids.size(0) + + all_input_ids.append(input_ids) + all_labels.append(labels) + all_position_ids.append(torch.arange(seq_len, dtype=torch.long)) + + # concat + input_ids = torch.cat(all_input_ids, dim=0) + labels = torch.cat(all_labels, dim=0) + position_ids = torch.cat(all_position_ids, dim=0) + attention_mask = torch.ones(input_ids.size(0), dtype=torch.long) + + # pad to packing_length + total_len = input_ids.size(0) + if total_len < self.packing_length: + pad_len = self.packing_length - total_len + pad_id = ( + self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id is not None + else 0 + ) + input_ids = torch.cat( + [input_ids, torch.full((pad_len,), pad_id, dtype=torch.long)] + ) + labels = torch.cat( + [labels, torch.full((pad_len,), -100, dtype=torch.long)] + ) + position_ids = torch.cat( + [position_ids, torch.zeros(pad_len, dtype=torch.long)] + ) + attention_mask = torch.cat( + [attention_mask, torch.zeros(pad_len, dtype=torch.long)] + ) + + self._show_train_sample(input_ids=input_ids, labels=labels) + + return dict( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=position_ids, + ) + + +# class SFTDataset(Dataset): + +# def __init__( +# self, +# train_dataset, +# tokenizer, +# max_length, +# ): +# self.train_dataset = train_dataset +# self.tokenizer = tokenizer +# self.max_length = max_length + +# self.has_shown_sample = False + +# def __len__(self): +# return len(self.train_dataset) + + +# def create_conversation_manually(self, messages, tools): + +# full = self.tokenizer.apply_chat_template( +# messages, +# tokenize=True, +# add_generation_prompt=False, +# return_dict=True, +# truncation=True, +# max_length=self.max_length, +# tools=tools if tools else None, +# ) + +# input_ids = full["input_ids"] +# attention_mask = full["attention_mask"] + +# assistant_masks = [0] * len(input_ids) +# current_pos = 0 + +# for i, message in enumerate(messages): +# if message["role"] == "assistant": + +# context_with_reply = messages[: i + 1] +# full_tokens = self.tokenizer.apply_chat_template( +# context_with_reply, +# tokenize=True, +# add_generation_prompt=False, +# truncation=True, +# max_length=self.max_length, +# tools=tools if tools else None, +# ) +# reply_end_pos = len(full_tokens) + +# assistant_masks[current_pos:reply_end_pos] = [1] * (reply_end_pos - current_pos) + +# else: + +# prompt_context = messages[: i + 1] + +# if message["role"] == "system": +# continue + +# else: +# prompt_tokens = self.tokenizer.apply_chat_template( +# prompt_context, +# tokenize=True, +# add_generation_prompt=True, +# truncation=True, +# max_length=self.max_length, +# tools=tools if tools else None, +# ) +# current_pos = len(prompt_tokens) + +# input_ids = torch.tensor(input_ids, dtype=torch.long) +# attention_mask = torch.tensor(attention_mask, dtype=torch.long) +# labels = input_ids.clone() + +# labels[torch.tensor(assistant_masks, dtype=torch.bool) == 0] = -100 + +# return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + +# def _show_train_sample(self, input_ids, labels): + +# if dist.is_available() and dist.is_initialized(): +# rank = dist.get_rank() +# else: +# rank = 0 +# worker_info = get_worker_info() +# is_main_worker = (worker_info is None) or (worker_info.id == 0) +# if rank == 0 and is_main_worker and not self.has_shown_sample: +# show_sample( +# input_ids=input_ids, +# labels=labels, +# tokenizer = self.tokenizer, +# title="SFT Input and Labels", +# left_column="Input IDs", +# right_column="Labels" +# ) +# self.has_shown_sample = True + +# def __getitem__(self, idx): +# messages = self.train_dataset[idx]["messages"] +# tools = self.train_dataset[idx]["tools"] +# sample = self.create_conversation_manually(messages, tools) + +# self._show_train_sample( +# input_ids=sample["input_ids"], +# labels=sample["labels"], +# ) + +# return sample + class DPODataset(Dataset): def __init__(self, train_dataset, tokenizer, max_length): diff --git a/src/bumblecore/training/sft_trainer.py b/src/bumblecore/training/sft_trainer.py index 8bae3be..6161060 100644 --- a/src/bumblecore/training/sft_trainer.py +++ b/src/bumblecore/training/sft_trainer.py @@ -29,5 +29,7 @@ def _prepare_datasets(self): messages, tokenizer, max_length=self.config.cutoff_len, + packing=self.config.packing, + packing_num_proc=self.config.packing_num_proc, ) return tokenizer, train_dataset From 5ea95d339d0abefa5da3a88c97741925d4e0c8a2 Mon Sep 17 00:00:00 2001 From: JasonCZH <74161960+JasonCZH4@users.noreply.github.com> Date: Sat, 7 Feb 2026 22:56:55 +0800 Subject: [PATCH 2/6] Remove commented-out SFTDataset class Removed commented-out SFTDataset class and its methods. --- src/bumblecore/data_processing/datasets.py | 113 +-------------------- 1 file changed, 1 insertion(+), 112 deletions(-) diff --git a/src/bumblecore/data_processing/datasets.py b/src/bumblecore/data_processing/datasets.py index 167ccfe..f335886 100644 --- a/src/bumblecore/data_processing/datasets.py +++ b/src/bumblecore/data_processing/datasets.py @@ -450,117 +450,6 @@ def _getitem_packing(self, idx): ) -# class SFTDataset(Dataset): - -# def __init__( -# self, -# train_dataset, -# tokenizer, -# max_length, -# ): -# self.train_dataset = train_dataset -# self.tokenizer = tokenizer -# self.max_length = max_length - -# self.has_shown_sample = False - -# def __len__(self): -# return len(self.train_dataset) - - -# def create_conversation_manually(self, messages, tools): - -# full = self.tokenizer.apply_chat_template( -# messages, -# tokenize=True, -# add_generation_prompt=False, -# return_dict=True, -# truncation=True, -# max_length=self.max_length, -# tools=tools if tools else None, -# ) - -# input_ids = full["input_ids"] -# attention_mask = full["attention_mask"] - -# assistant_masks = [0] * len(input_ids) -# current_pos = 0 - -# for i, message in enumerate(messages): -# if message["role"] == "assistant": - -# context_with_reply = messages[: i + 1] -# full_tokens = self.tokenizer.apply_chat_template( -# context_with_reply, -# tokenize=True, -# add_generation_prompt=False, -# truncation=True, -# max_length=self.max_length, -# tools=tools if tools else None, -# ) -# reply_end_pos = len(full_tokens) - -# assistant_masks[current_pos:reply_end_pos] = [1] * (reply_end_pos - current_pos) - -# else: - -# prompt_context = messages[: i + 1] - -# if message["role"] == "system": -# continue - -# else: -# prompt_tokens = self.tokenizer.apply_chat_template( -# prompt_context, -# tokenize=True, -# add_generation_prompt=True, -# truncation=True, -# max_length=self.max_length, -# tools=tools if tools else None, -# ) -# current_pos = len(prompt_tokens) - -# input_ids = torch.tensor(input_ids, dtype=torch.long) -# attention_mask = torch.tensor(attention_mask, dtype=torch.long) -# labels = input_ids.clone() - -# labels[torch.tensor(assistant_masks, dtype=torch.bool) == 0] = -100 - -# return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - - -# def _show_train_sample(self, input_ids, labels): - -# if dist.is_available() and dist.is_initialized(): -# rank = dist.get_rank() -# else: -# rank = 0 -# worker_info = get_worker_info() -# is_main_worker = (worker_info is None) or (worker_info.id == 0) -# if rank == 0 and is_main_worker and not self.has_shown_sample: -# show_sample( -# input_ids=input_ids, -# labels=labels, -# tokenizer = self.tokenizer, -# title="SFT Input and Labels", -# left_column="Input IDs", -# right_column="Labels" -# ) -# self.has_shown_sample = True - -# def __getitem__(self, idx): -# messages = self.train_dataset[idx]["messages"] -# tools = self.train_dataset[idx]["tools"] -# sample = self.create_conversation_manually(messages, tools) - -# self._show_train_sample( -# input_ids=sample["input_ids"], -# labels=sample["labels"], -# ) - -# return sample - - class DPODataset(Dataset): def __init__(self, train_dataset, tokenizer, max_length): self.train_dataset = train_dataset @@ -739,4 +628,4 @@ def _right_pad_to_len(sequences, max_length, padding_value): device=padded.device ) padded = torch.cat([padded, pad_tensor], dim=1) - return padded \ No newline at end of file + return padded From 470021bfba7e856ab49e230fd27409ddd0fead7b Mon Sep 17 00:00:00 2001 From: 19742021 <2458807275@qq.com> Date: Mon, 9 Feb 2026 22:10:46 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E6=8F=90=E5=8D=87=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=81=A5=E5=A3=AE=E6=80=A7=EF=BC=8CHfDeepSpeedConfig=E5=8C=85?= =?UTF-8?q?=E5=90=AB=E4=BA=86zero.Init=E7=9A=84=E5=8A=9F=E8=83=BD=EF=BC=8C?= =?UTF-8?q?=E5=8E=BB=E9=99=A4=E9=87=8D=E5=8F=A0=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bumblecore/training/base_trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/bumblecore/training/base_trainer.py b/src/bumblecore/training/base_trainer.py index 5f7a3b3..472b49a 100644 --- a/src/bumblecore/training/base_trainer.py +++ b/src/bumblecore/training/base_trainer.py @@ -171,8 +171,8 @@ def _instantiate_model(): if zero_stage == 3 and self.config.training_stage != "pretrain": self.dschf = HfDeepSpeedConfig(self.deepspeed_config) - with deepspeed.zero.Init(config_dict_or_path=self.deepspeed_config): - return _instantiate_model() + # with deepspeed.zero.Init(config_dict_or_path=self.deepspeed_config): + return _instantiate_model() else: return _instantiate_model() @@ -520,7 +520,9 @@ def _reset_accumulated_results(self): def _handle_logging_and_progress(self, global_step, computation_result): if global_step % self.config.logging_steps == 0: lr = self.model_engine.get_lr()[0] - grad_norm = self.model_engine.get_global_grad_norm().item() + grad_norm = self.model_engine.get_global_grad_norm() + if isinstance(grad_norm, torch.Tensor): + grad_norm = grad_norm.item() epochs_done = min(global_step / self.num_update_steps_per_epoch, math.ceil(self.config.num_epochs)) self.log_metrics(lr, grad_norm, epochs_done, global_step, computation_result) From 17d7cb61bc529f490164e4a04494a3395992a7a7 Mon Sep 17 00:00:00 2001 From: JasonCZH4 Date: Fri, 13 Feb 2026 10:22:32 +0000 Subject: [PATCH 4/6] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=8E=9F=E6=9D=A5packing?= =?UTF-8?q?=E7=9A=84bug=EF=BC=8C=E5=8A=A0=E5=85=A5test=5Fpacking?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/bumblecore/cli/arg_parser.py | 2 + src/bumblecore/config/train_config.py | 6 +- src/bumblecore/data_processing/__init__.py | 27 +- .../data_processing/dataset_utils.py | 114 ++++ src/bumblecore/data_processing/datasets.py | 190 +++---- src/bumblecore/training/base_trainer.py | 12 +- src/bumblecore/training/sft_trainer.py | 8 +- .../data_processing/test_packing.py | 538 ++++++++++++++++++ 8 files changed, 782 insertions(+), 115 deletions(-) create mode 100644 src/bumblecore/data_processing/dataset_utils.py create mode 100644 tests/bumblecore/data_processing/test_packing.py diff --git a/src/bumblecore/cli/arg_parser.py b/src/bumblecore/cli/arg_parser.py index e8ddaf1..af8c614 100644 --- a/src/bumblecore/cli/arg_parser.py +++ b/src/bumblecore/cli/arg_parser.py @@ -47,6 +47,8 @@ def get_args(): action="store_true", default=cfg.get("enable_gradient_checkpointing", False), ) + parser.add_argument("--packing", type=bool, default=cfg.get("packing", True)) + parser.add_argument("--packing_num_proc", type=int, default=cfg.get("packing_num_proc", 4)) parser.add_argument("--warmup_ratio", type=float, default=cfg.get("warmup_ratio", 0.1)) # Batch Size diff --git a/src/bumblecore/config/train_config.py b/src/bumblecore/config/train_config.py index 48dcb3d..f202d53 100644 --- a/src/bumblecore/config/train_config.py +++ b/src/bumblecore/config/train_config.py @@ -18,8 +18,6 @@ class TrainConfig: weight_decay: float = field(default=0.01) warmup_ratio: float = field(default=0.1) num_epochs: float = field(default=3.0) - packing: bool = field(default=False) - packing_num_proc: int = field(default=1) lr_scheduler_type: str = field(default="cosine") train_micro_batch_size_per_gpu: int = field(default=4) @@ -48,6 +46,10 @@ class TrainConfig: lora_dropout: float = field(default=0.1) lora_target_modules: Optional[Union[List[str], str]] = field(default=None) + # Packing settings + packing: bool = field(default=False) + packing_num_proc: int = field(default=1) + ld_alpha: float = field(default=1.0) pref_beta: float = field(default=0.1) diff --git a/src/bumblecore/data_processing/__init__.py b/src/bumblecore/data_processing/__init__.py index f8d24a9..f229e6a 100644 --- a/src/bumblecore/data_processing/__init__.py +++ b/src/bumblecore/data_processing/__init__.py @@ -1,15 +1,34 @@ -from .datasets import SFTDataset,PretrainDataset,DataCollator,DPODataset,DPOCollator +from .datasets import ( + SFTDataset, PretrainDataset, DataCollator, DPODataset, DPOCollator, + PackingDataCollator, +) from .preprocess import load_pretrain_data,load_sft_data,load_dpo_data from .data_formatter import DataFormatter +from .dataset_utils import ( + show_sample, + get_padding_value, + calculate_matched_group, + split_list, + is_master, + is_distributed, +) __all__ = [ - "SFTDataset", - "PretrainDataset", + "SFTDataset", + "PretrainDataset", "DataCollator", "load_pretrain_data", "load_sft_data", "load_dpo_data", "DataFormatter", "DPODataset", - "DPOCollator" + "DPOCollator", + "PackingDataCollator", + # Utility functions + "show_sample", + "get_padding_value", + "calculate_matched_group", + "split_list", + "is_master", + "is_distributed", ] diff --git a/src/bumblecore/data_processing/dataset_utils.py b/src/bumblecore/data_processing/dataset_utils.py new file mode 100644 index 0000000..dae0788 --- /dev/null +++ b/src/bumblecore/data_processing/dataset_utils.py @@ -0,0 +1,114 @@ +"""Utility functions for dataset processing.""" + +import io +from typing import List, Tuple + +import torch +import torch.distributed as dist +from rich.console import Console +from rich.table import Table +from rich.text import Text +from tqdm import tqdm + + +def show_sample( + input_ids, + labels, + tokenizer, + title="Input and Labels", + left_column="Input IDs", + right_column="Labels" +): + """Display a sample with input_ids and labels in a formatted table.""" + input_ids = input_ids.tolist() + labels = labels.tolist() + + valid_labels_list = [token_id for token_id in labels if token_id != -100] + decoded_input = tokenizer.decode(input_ids) + decoded_labels = tokenizer.decode(valid_labels_list) + + table = Table(show_header=True, show_lines=True, title=title) + table.add_column(left_column, overflow="fold") + table.add_column(right_column, overflow="fold") + + wrapped_input = Text(decoded_input, no_wrap=False, overflow="fold") + wrapped_labels = Text(decoded_labels, no_wrap=False, overflow="fold") + + table.add_row(str(input_ids), str(labels)) + table.add_row(wrapped_input, wrapped_labels) + + with io.StringIO() as buf: + console = Console(file=buf, force_terminal=False) + console.print(table) + output = buf.getvalue() + + tqdm.write(output.rstrip()) + + +def get_padding_value(tokenizer): + """Get the padding token id from tokenizer. + + If pad_token_id is not set, use eos_token_id as fallback. + """ + if tokenizer.pad_token_id is not None: + return tokenizer.pad_token_id + + eos = tokenizer.eos_token_id + return eos[0] if isinstance(eos, list) else eos + + +def calculate_matched_group(sequences: List[Tuple[int, int]], packing_length: int, is_finished: bool = True): + """Bin-packing via First Fit Decreasing (https://arxiv.org/pdf/2404.10830). + + Args: + sequences: List of (index, length) tuples. + packing_length: Maximum length for each pack. + is_finished: Whether this is the last batch. + + Returns: + Tuple of (packed_sequences, remaining_sequences). + packed_sequences is a list of lists, each containing (index, length) tuples. + """ + if len(sequences) == 0: + return [], [] + import binpacking + + # sequences 是 [(index, length), ...] 列表 + # weight_pos=1 表示长度在元组第二个位置 + # 将一组物品分配到多个容量固定的箱子(bins)中,使得每个箱子的总容量不超过指定的最大值。 + sequences = binpacking.to_constant_volume(sequences, packing_length, weight_pos=1) + + # sequences 是列表的列表,每个子列表包含多个 (index, length) 元组 + # 如果不是最后一批,保留最后一个不完整组用于下一批 + if sequences and not is_finished: + sequences, ret_sequences = sequences[:-1], sequences[-1] + else: + ret_sequences = [] + return sequences, ret_sequences + + +def split_list(lst: list, n: int) -> List[list]: + """Split a list into n sublists as evenly as possible. + + Args: + lst: The list to split. + n: Number of parts to split into. + + Returns: + List of n sublists. + """ + # 划分列表为n个子列表,对应n个子进程处理 + k, m = divmod(len(lst), n) + return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)] + + +def is_master() -> bool: + """Check if current process is the master process in distributed training.""" + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() == 0 + return True + + +def is_distributed() -> bool: + """Check if running in distributed training mode.""" + return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 diff --git a/src/bumblecore/data_processing/datasets.py b/src/bumblecore/data_processing/datasets.py index f335886..9fef5a8 100644 --- a/src/bumblecore/data_processing/datasets.py +++ b/src/bumblecore/data_processing/datasets.py @@ -1,57 +1,20 @@ -import io - import torch import torch.distributed as dist -from torch.utils.data import Dataset,get_worker_info -from rich.console import Console -from rich.table import Table -from rich.text import Text +from torch.utils.data import Dataset, get_worker_info from tqdm import tqdm import math import multiprocessing as mp from itertools import chain - -def show_sample( - input_ids, - labels, - tokenizer, - title="Input and Labels" , - left_column = "Input IDs", - right_column = "Labels" -): - input_ids = input_ids.tolist() - labels = labels.tolist() - - valid_labels_list = [token_id for token_id in labels if token_id != -100] - decoded_input = tokenizer.decode(input_ids) - decoded_labels = tokenizer.decode(valid_labels_list) - - table = Table(show_header=True, show_lines=True, title=title) - table.add_column(left_column, overflow="fold") - table.add_column(right_column, overflow="fold") - - wrapped_input = Text(decoded_input, no_wrap=False, overflow="fold") - wrapped_labels = Text(decoded_labels, no_wrap=False, overflow="fold") - - table.add_row(str(input_ids), str(labels)) - table.add_row(wrapped_input, wrapped_labels) - - with io.StringIO() as buf: - console = Console(file=buf, force_terminal=False) - console.print(table) - output = buf.getvalue() - - tqdm.write(output.rstrip()) - - -def get_padding_value(tokenizer): - if tokenizer.pad_token_id is not None: - return tokenizer.pad_token_id - - eos = tokenizer.eos_token_id - return eos[0] if isinstance(eos, list) else eos +from .dataset_utils import ( + show_sample, + get_padding_value, + calculate_matched_group, + split_list, + is_master, + is_distributed, +) class PretrainDataset(Dataset): @@ -123,38 +86,8 @@ def __getitem__(self, idx): return sample -def calculate_matched_group(sequences, packing_length: int, is_finished: bool = True): - """Bin-packing via First Fit Decreasing (https://arxiv.org/pdf/2404.10830).""" - if len(sequences) == 0: - return [], [] - import binpacking - # sequences 是 [(index, length), ...] 列表 - # weight_pos=1 表示长度在元组第二个位置 - # 将一组物品分配到多个容量固定的箱子(bins)中,使得每个箱子的总容量不超过指定的最大值。 - sequences = binpacking.to_constant_volume(sequences, packing_length, weight_pos=1) - # sequences 是列表的列表,每个子列表包含多个 (index, length) 元组 - # 如果不是最后一批,保留最后一个不完整组用于下一批 - if sequences and not is_finished: - sequences, ret_sequences = sequences[:-1], sequences[-1] - else: - ret_sequences = [] - return sequences, ret_sequences - -def split_list(lst, n): - # 划分列表为n个子列表,对应n个子进程处理 - k, m = divmod(len(lst), n) - return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)] -def _is_master(): - if dist.is_available() and dist.is_initialized(): - return dist.get_rank() == 0 - return True - - -def _is_dist(): - return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 - class SFTDataset(Dataset): PACKING_BATCH_SIZE = 1000 @@ -172,6 +105,9 @@ def __init__( self.max_length = max_length self.has_shown_sample = False + if len(train_dataset) == 0: + raise ValueError("train_dataset cannot be empty") + # ── packing bookkeeping ── self.packing = packing self.packing_length = max_length @@ -209,7 +145,7 @@ def _compute_lengths(self) -> list[int]: def _setup_packing(self): """Build packed_idx / packed_length using multi-process bin-packing.""" - if _is_master(): + if is_master(): # 计算每条数据的长度 lengths = self._compute_lengths() offset = 0 @@ -253,7 +189,7 @@ def _setup_packing(self): self.packed_idx, self.packed_length = None, None # broadcast to all ranks - if _is_dist(): + if is_distributed(): obj_list = [(self.packed_idx, self.packed_length)] dist.broadcast_object_list(obj_list) self.packed_idx, self.packed_length = obj_list[0] @@ -372,7 +308,7 @@ def __getitem__(self, idx): return self._getitem_packing(idx) sample = self._process_single_sample(idx) - self._show_train_sample(input_ids=sample["input_ids"], labels=sample["labels"]) + # self._show_train_sample(input_ids=sample["input_ids"], labels=sample["labels"]) return sample # ── packing __getitem__ ────────────────────────────────────────────── @@ -380,21 +316,19 @@ def __getitem__(self, idx): def _getitem_packing(self, idx): """ Concatenate the samples assigned to this pack, add per-sequence - position_ids (reset to 0 at each boundary), and pad to - ``packing_length`` so every item in the batch has the same shape. + position_ids (reset to 0 at each boundary). Returns ------- dict with keys: input_ids, attention_mask, labels, position_ids - All of shape ``(packing_length,)``. + All tensors are concatenated but NOT padded. + Padding is handled by PackingDataCollator. Notes ----- * ``position_ids`` resets to 0 at each sequence boundary, which Flash-Attention-2 / flex-attention can use to build a block-diagonal mask automatically. - * Padding tokens get ``label = -100``, ``attention_mask = 0``, - ``position_ids = 0``. """ sequence_indices = self.packed_idx[idx] @@ -418,28 +352,6 @@ def _getitem_packing(self, idx): position_ids = torch.cat(all_position_ids, dim=0) attention_mask = torch.ones(input_ids.size(0), dtype=torch.long) - # pad to packing_length - total_len = input_ids.size(0) - if total_len < self.packing_length: - pad_len = self.packing_length - total_len - pad_id = ( - self.tokenizer.pad_token_id - if self.tokenizer.pad_token_id is not None - else 0 - ) - input_ids = torch.cat( - [input_ids, torch.full((pad_len,), pad_id, dtype=torch.long)] - ) - labels = torch.cat( - [labels, torch.full((pad_len,), -100, dtype=torch.long)] - ) - position_ids = torch.cat( - [position_ids, torch.zeros(pad_len, dtype=torch.long)] - ) - attention_mask = torch.cat( - [attention_mask, torch.zeros(pad_len, dtype=torch.long)] - ) - self._show_train_sample(input_ids=input_ids, labels=labels) return dict( @@ -629,3 +541,69 @@ def _right_pad_to_len(sequences, max_length, padding_value): ) padded = torch.cat([padded, pad_tensor], dim=1) return padded + + +class PackingDataCollator: + """ + Data collator for packed sequences. + + Pads a batch of packed sequences (with varying lengths) to the max length + in the batch. Similar to DataCollator but also handles position_ids. + """ + + def __init__(self, tokenizer): + self.input_ids_padding_value = get_padding_value(tokenizer=tokenizer) + + def __call__(self, batch): + if not batch: + return dict( + input_ids=torch.tensor([], dtype=torch.long).reshape(0, 0), + attention_mask=torch.tensor([], dtype=torch.long).reshape(0, 0), + labels=torch.tensor([], dtype=torch.long).reshape(0, 0), + position_ids=torch.tensor([], dtype=torch.long).reshape(0, 0), + ) + + input_ids = [item["input_ids"] for item in batch] + attention_mask = [item["attention_mask"] for item in batch] + labels = [item["labels"] for item in batch] + position_ids = [item["position_ids"] for item in batch] + + # Find max length in this batch + max_length = max(len(x) for x in input_ids) + + # Pad all sequences to max_length + input_ids = self._right_pad_to_len( + input_ids, max_length, self.input_ids_padding_value + ) + attention_mask = self._right_pad_to_len( + attention_mask, max_length, 0 + ) + labels = self._right_pad_to_len( + labels, max_length, -100 + ) + position_ids = self._right_pad_to_len( + position_ids, max_length, 0 + ) + + return dict( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=position_ids, + ) + + @staticmethod + def _right_pad_to_len(sequences, max_length, padding_value): + padded = torch.nn.utils.rnn.pad_sequence( + sequences, batch_first=True, padding_value=padding_value + ) + if padded.size(1) < max_length: + diff = max_length - padded.size(1) + pad_tensor = torch.full( + (padded.size(0), diff), + padding_value, + dtype=padded.dtype, + device=padded.device + ) + padded = torch.cat([padded, pad_tensor], dim=1) + return padded diff --git a/src/bumblecore/training/base_trainer.py b/src/bumblecore/training/base_trainer.py index 472b49a..bcf69b4 100644 --- a/src/bumblecore/training/base_trainer.py +++ b/src/bumblecore/training/base_trainer.py @@ -542,7 +542,17 @@ def compute_loss(self, model_engine, batch): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] labels = batch["labels"] - output = model_engine(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) + + # Support packing: pass position_ids if present + model_kwargs = dict( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False + ) + if "position_ids" in batch: + model_kwargs["position_ids"] = batch["position_ids"] + + output = model_engine(**model_kwargs) if self.config.average_tokens_across_devices: count = (labels != -100).sum().float() dist.all_reduce(count, op=dist.ReduceOp.SUM) diff --git a/src/bumblecore/training/sft_trainer.py b/src/bumblecore/training/sft_trainer.py index 6161060..6e91489 100644 --- a/src/bumblecore/training/sft_trainer.py +++ b/src/bumblecore/training/sft_trainer.py @@ -1,7 +1,7 @@ from transformers import AutoTokenizer from .base_trainer import BaseTrainer -from ..data_processing import DataFormatter, SFTDataset, load_sft_data, DataCollator +from ..data_processing import DataFormatter, SFTDataset, load_sft_data, DataCollator, PackingDataCollator from ..config import TrainConfig class SFTTrainer(BaseTrainer): @@ -9,7 +9,11 @@ def __init__(self, config: TrainConfig): self.config = config self.format_preprocess_fn = DataFormatter(self.config.training_stage) self.tokenizer, self.train_dataset = self._prepare_datasets() - self.data_collator = DataCollator(self.tokenizer) + # Use PackingDataCollator when packing is enabled, otherwise use standard DataCollator + if self.config.packing: + self.data_collator = PackingDataCollator(self.tokenizer) + else: + self.data_collator = DataCollator(self.tokenizer) super().__init__(config, self.train_dataset, self.tokenizer, self.data_collator) self._print_train_parameters() diff --git a/tests/bumblecore/data_processing/test_packing.py b/tests/bumblecore/data_processing/test_packing.py new file mode 100644 index 0000000..5c16a95 --- /dev/null +++ b/tests/bumblecore/data_processing/test_packing.py @@ -0,0 +1,538 @@ +import torch +import pytest +from transformers import AutoTokenizer +from bumblecore.data_processing import ( + SFTDataset, + DataCollator, + PackingDataCollator, +) + +tokenizer = AutoTokenizer.from_pretrained("./models/bumblebee") + + +# ============================== +# PackingDataCollator 测试 +# ============================== + +def test_packing_data_collator_basic(): + """测试 PackingDataCollator 的基本功能""" + collator = PackingDataCollator(tokenizer) + + # 创建包含 position_ids 的批次数据 + batch = [ + { + "input_ids": torch.tensor([1, 10, 20, 30, 2]), + "attention_mask": torch.tensor([1, 1, 1, 1, 1]), + "labels": torch.tensor([1, 10, 20, 30, 2]), + "position_ids": torch.tensor([0, 1, 2, 3, 4]), + }, + { + "input_ids": torch.tensor([1, 15, 25, 2]), + "attention_mask": torch.tensor([1, 1, 1, 1]), + "labels": torch.tensor([1, 15, 25, 2]), + "position_ids": torch.tensor([0, 1, 2, 3]), + }, + ] + + result = collator(batch) + + pad_token_id = tokenizer.pad_token_id + expected = { + "input_ids": torch.tensor([ + [1, 10, 20, 30, 2], + [1, 15, 25, 2, pad_token_id], + ], dtype=torch.long), + "attention_mask": torch.tensor([ + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 0], + ], dtype=torch.long), + "labels": torch.tensor([ + [1, 10, 20, 30, 2], + [1, 15, 25, 2, -100], + ], dtype=torch.long), + "position_ids": torch.tensor([ + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 0], + ], dtype=torch.long), + } + + assert torch.equal(result["input_ids"], expected["input_ids"]), f"input_ids 不匹配, 结果: {result['input_ids']}, 期望: {expected['input_ids']}" + assert torch.equal(result["attention_mask"], expected["attention_mask"]), "attention_mask 不匹配" + assert torch.equal(result["labels"], expected["labels"]), "labels 不匹配" + assert torch.equal(result["position_ids"], expected["position_ids"]), "position_ids 不匹配" + + +def test_packing_data_collator_with_varying_lengths(): + """测试 PackingDataCollator 处理不同长度序列""" + collator = PackingDataCollator(tokenizer) + + batch = [ + { + "input_ids": torch.tensor([1, 10, 2]), + "attention_mask": torch.tensor([1, 1, 1]), + "labels": torch.tensor([-100, 10, 2]), + "position_ids": torch.tensor([0, 1, 2]), + }, + { + "input_ids": torch.tensor([1, 15, 25, 35, 45, 2]), + "attention_mask": torch.tensor([1, 1, 1, 1, 1, 1]), + "labels": torch.tensor([-100, -100, 25, 35, 45, 2]), + "position_ids": torch.tensor([0, 1, 2, 3, 4, 5]), + }, + { + "input_ids": torch.tensor([1, 12, 2]), + "attention_mask": torch.tensor([1, 1, 1]), + "labels": torch.tensor([-100, 12, 2]), + "position_ids": torch.tensor([0, 1, 2]), + }, + ] + + result = collator(batch) + + pad_token_id = tokenizer.pad_token_id + expected = { + "input_ids": torch.tensor([ + [1, 10, 2, pad_token_id, pad_token_id, pad_token_id], + [1, 15, 25, 35, 45, 2], + [1, 12, 2, pad_token_id, pad_token_id, pad_token_id], + ], dtype=torch.long), + "attention_mask": torch.tensor([ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0], + ], dtype=torch.long), + "labels": torch.tensor([ + [-100, 10, 2, -100, -100, -100], + [-100, -100, 25, 35, 45, 2], + [-100, 12, 2, -100, -100, -100], + ], dtype=torch.long), + "position_ids": torch.tensor([ + [0, 1, 2, 0, 0, 0], + [0, 1, 2, 3, 4, 5], + [0, 1, 2, 0, 0, 0], + ], dtype=torch.long), + } + + assert torch.equal(result["input_ids"], expected["input_ids"]) + assert torch.equal(result["attention_mask"], expected["attention_mask"]) + assert torch.equal(result["labels"], expected["labels"]) + assert torch.equal(result["position_ids"], expected["position_ids"]) + + +def test_packing_data_collator_empty_batch(): + """测试 PackingDataCollator 处理空批次""" + collator = PackingDataCollator(tokenizer) + + batch = [] + + result = collator(batch) + + assert result["input_ids"].shape == (0, 0) + assert result["attention_mask"].shape == (0, 0) + assert result["labels"].shape == (0, 0) + assert result["position_ids"].shape == (0, 0) + + +# ============================== +# SFTDataset Packing 功能测试 +# ============================== + +def test_sft_dataset_packing_disabled(): + """测试 SFTDataset 禁用 packing 时的行为""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 禁用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=False) + + # 检查长度 + assert len(dataset) == len(train_dataset) + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" not in result # 禁用 packing 时不返回 position_ids + + # 检查数据类型 + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + + +def test_sft_dataset_packing_enabled(): + """测试 SFTDataset 启用 packing 时的基本行为""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 检查长度(packing 会改变数据集长度) + assert len(dataset) >= 1 # 至少有一个pack + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result # 启用 packing 时返回 position_ids + + # 检查数据类型 + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + assert isinstance(result["position_ids"], torch.Tensor) + + # 检查序列长度一致性 + input_ids_len = len(result["input_ids"]) + assert len(result["attention_mask"]) == input_ids_len + assert len(result["labels"]) == input_ids_len + assert len(result["position_ids"]) == input_ids_len + + +def test_sft_dataset_packing_with_single_sample(): + """测试 SFTDataset packing 处理单个样本""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 检查长度 + assert len(dataset) == 1 # 单个样本应该只有一个pack + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result + + # 检查 position_ids 是否正确重置 + position_ids = result["position_ids"] + assert torch.equal(position_ids, torch.arange(len(position_ids), dtype=torch.long)) + + +def test_sft_dataset_packing_position_ids_reset(): + """测试 SFTDataset packing 中 position_ids 的重置逻辑""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "Short message."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "Another short message."}, + {"role": "user", "content": "Hey"}, + {"role": "assistant", "content": "Hi there"}, + ], + "tools": None, + } + ] + + max_length = 100 # 设置较小的max_length以触发packing + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 获取样本 + result = dataset[0] + + # 检查 position_ids 是否正确重置 + position_ids = result["position_ids"] + + # 获取 packed_idx 中的实际顺序 + sequence_indices = dataset.packed_idx[0] + + # 根据实际顺序验证 position_ids + offset = 0 + for seq_idx in sequence_indices: + sample = dataset._process_single_sample(seq_idx) + seq_len = len(sample["input_ids"]) + + # 验证当前序列的 position_ids 正确重置为 [0, 1, ..., seq_len-1] + expected_position_ids = torch.arange(seq_len, dtype=torch.long) + actual_position_ids = position_ids[offset:offset + seq_len] + assert torch.equal(actual_position_ids, expected_position_ids), \ + f"Sequence {seq_idx}: expected {expected_position_ids.tolist()}, got {actual_position_ids.tolist()}" + + offset += seq_len + + # 验证总长度正确 + assert offset == len(position_ids), f"Total length mismatch: {offset} vs {len(position_ids)}" + + +def test_sft_dataset_packing_with_tools(): + """测试 SFTDataset packing 带 tools 的情况""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Calculate 2+2"}, + {"role": "assistant", "content": "4"}, + ], + "tools": [{"name": "calculator"}], + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 3*3?"}, + {"role": "assistant", "content": "9"}, + ], + "tools": [{"name": "calculator"}], + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 检查长度 + assert len(dataset) >= 1 + + # 获取样本 + result = dataset[0] + + # 检查返回的字段 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result + + # 检查数据类型 + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert isinstance(result["labels"], torch.Tensor) + assert isinstance(result["position_ids"], torch.Tensor) + + +def test_sft_dataset_packing_labels_consistency(): + """测试 SFTDataset packing 中 labels 的一致性""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 启用 packing + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 获取样本 + result = dataset[0] + + input_ids = result["input_ids"] + labels = result["labels"] + attention_mask = result["attention_mask"] + + # 检查 labels 中非 -100 的位置与 input_ids 一致 + non_negative_100_mask = labels != -100 + assert torch.equal(input_ids[non_negative_100_mask], labels[non_negative_100_mask]) + + # 检查 attention_mask 的有效性 + assert torch.all(attention_mask == 1) # packing 时 attention_mask 应该全为1 + + +def test_sft_dataset_packing_edge_cases(): + """测试 SFTDataset packing 的边缘情况""" + + # 测试空数据集 + empty_dataset = [] + with pytest.raises(ValueError, match="train_dataset cannot be empty"): + dataset = SFTDataset(empty_dataset, tokenizer, max_length=256, packing=True) + + # 测试单个长样本 + long_message_dataset = [ + { + "messages": [ + {"role": "system", "content": "A" * 1000}, # 很长的消息 + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 50 # 设置较小的max_length + + dataset = SFTDataset(long_message_dataset, tokenizer, max_length, packing=True) + result = dataset[0] + + # 检查序列长度不超过max_length + assert len(result["input_ids"]) <= max_length + + +# ============================== +# 集成测试:PackingDataCollator + SFTDataset +# ============================== + +def test_packing_integration(): + """测试 PackingDataCollator 和 SFTDataset 的集成""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + }, + { + "messages": [ + {"role": "system", "content": "You are a math tutor."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 创建带 packing 的 dataset + dataset = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + + # 创建 PackingDataCollator + collator = PackingDataCollator(tokenizer) + + # 创建批次 + batch = [dataset[i] for i in range(min(2, len(dataset)))] + + # 使用 collator 处理批次 + result = collator(batch) + + # 检查结果 + assert "input_ids" in result + assert "attention_mask" in result + assert "labels" in result + assert "position_ids" in result + + # 检查批次维度 + batch_size = len(batch) + assert result["input_ids"].shape[0] == batch_size + assert result["attention_mask"].shape[0] == batch_size + assert result["labels"].shape[0] == batch_size + assert result["position_ids"].shape[0] == batch_size + + +def test_packing_vs_standard_collator(): + """比较 PackingDataCollator 和标准 DataCollator 的区别""" + train_dataset = [ + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "tools": None, + } + ] + + max_length = 256 + + # 创建带 packing 的 dataset + dataset_packing = SFTDataset(train_dataset, tokenizer, max_length, packing=True) + dataset_standard = SFTDataset(train_dataset, tokenizer, max_length, packing=False) + + # 创建两种 collator + packing_collator = PackingDataCollator(tokenizer) + standard_collator = DataCollator(tokenizer) + + # 获取样本 + sample_packing = dataset_packing[0] + sample_standard = dataset_standard[0] + + # 检查 packing 样本包含 position_ids + assert "position_ids" in sample_packing + assert "position_ids" not in sample_standard + + # 使用 collator 处理 + batch_packing = packing_collator([sample_packing]) + batch_standard = standard_collator([sample_standard]) + + # 检查 packing collator 返回 position_ids + assert "position_ids" in batch_packing + assert "position_ids" not in batch_standard + + +if __name__ == "__main__": + # 运行测试 + test_packing_data_collator_basic() + test_packing_data_collator_with_varying_lengths() + test_packing_data_collator_empty_batch() + test_sft_dataset_packing_disabled() + test_sft_dataset_packing_enabled() + test_sft_dataset_packing_with_single_sample() + test_sft_dataset_packing_position_ids_reset() + test_sft_dataset_packing_with_tools() + test_sft_dataset_packing_labels_consistency() + test_sft_dataset_packing_edge_cases() + test_packing_integration() + test_packing_vs_standard_collator() + + print("所有 packing 功能测试通过!") \ No newline at end of file From 0485a612326bd39f6fe6016282932c1699302fbf Mon Sep 17 00:00:00 2001 From: JasonCZH4 Date: Fri, 13 Feb 2026 11:56:00 +0000 Subject: [PATCH 5/6] =?UTF-8?q?=E6=B7=BB=E5=8A=A0test=5Fpacking=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E4=BB=A3=E7=A0=81=EF=BC=8C=E5=BF=BD=E7=95=A5png?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + tests/run_test.sh | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index cbda94f..b7d5fc9 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,4 @@ experiments/*.md # DeepSpeed deepspeed_logs/ +*.png diff --git a/tests/run_test.sh b/tests/run_test.sh index 28f77ef..445bddf 100644 --- a/tests/run_test.sh +++ b/tests/run_test.sh @@ -3,4 +3,5 @@ pytest tests/bumblecore/data_processing/test_data_formatter.py -v -s pytest tests/bumblecore/data_processing/test_datasets.py -v pytest tests/bumblecore/cli/test_arg_parser.py -v -pytest tests/bumblecore/training/test_launcher.py -v \ No newline at end of file +pytest tests/bumblecore/training/test_launcher.py -v +pytest tests/bumblecore/data_processing/test_packing.py -v \ No newline at end of file From 1da4cea4b301c86cd69a24788ef52a7d44ba1936 Mon Sep 17 00:00:00 2001 From: JasonCZH4 Date: Fri, 13 Feb 2026 12:07:08 +0000 Subject: [PATCH 6/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9packing=E7=9B=B8=E5=85=B3?= =?UTF-8?q?yaml=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/sft/sft_full.yaml | 2 ++ configs/sft/sft_lora.yaml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/configs/sft/sft_full.yaml b/configs/sft/sft_full.yaml index ad2ac66..4bb44e6 100644 --- a/configs/sft/sft_full.yaml +++ b/configs/sft/sft_full.yaml @@ -18,6 +18,8 @@ weight_decay: 0.01 lr_scheduler_type: "cosine" enable_gradient_checkpointing: true warmup_ratio: 0.1 +packing: true +packing_num_proc: 4 # Batch Size train_micro_batch_size_per_gpu: 4 diff --git a/configs/sft/sft_lora.yaml b/configs/sft/sft_lora.yaml index b056311..c77ae4b 100644 --- a/configs/sft/sft_lora.yaml +++ b/configs/sft/sft_lora.yaml @@ -16,6 +16,8 @@ learning_rate: 5e-5 lr_scheduler_type: "cosine" enable_gradient_checkpointing: true warmup_ratio: 0.1 +packing: true +packing_num_proc: 4 # Batch Size train_micro_batch_size_per_gpu: 4