Support packing algorithm to speed up training #1
Open

JasonCZH4 wants to merge 6 commits into wxhcore:main from JasonCZH4:main
Commits
- 6f54f60 (JasonCZH4): Support packing algorithm to speed up training
- 5ea95d3 (JasonCZH4): Remove commented-out SFTDataset class
- 470021b (wxhcore): Improve code robustness; HfDeepSpeedConfig already covers zero.Init functionality, so remove the overlapping code
- 17d7cb6 (JasonCZH4): Fix a bug in the original packing implementation; add test_packing
- 0485a61 (JasonCZH4): Add a script to run test_packing; ignore png files
- 1da4cea (JasonCZH4): Update packing-related yaml config files
```diff
@@ -65,3 +65,4 @@ experiments/*.md

 # DeepSpeed
 deepspeed_logs/
+*.png
```
```diff
@@ -1,15 +1,34 @@
-from .datasets import SFTDataset,PretrainDataset,DataCollator,DPODataset,DPOCollator
+from .datasets import (
+    SFTDataset, PretrainDataset, DataCollator, DPODataset, DPOCollator,
+    PackingDataCollator,
+)
 from .preprocess import load_pretrain_data,load_sft_data,load_dpo_data
 from .data_formatter import DataFormatter
+from .dataset_utils import (
+    show_sample,
+    get_padding_value,
+    calculate_matched_group,
+    split_list,
+    is_master,
+    is_distributed,
+)

 __all__ = [
-    "SFTDataset",
-    "PretrainDataset",
+    "SFTDataset",
+    "PretrainDataset",
     "DataCollator",
     "load_pretrain_data",
     "load_sft_data",
     "load_dpo_data",
     "DataFormatter",
     "DPODataset",
-    "DPOCollator"
+    "DPOCollator",
+    "PackingDataCollator",
+    # Utility functions
+    "show_sample",
+    "get_padding_value",
+    "calculate_matched_group",
+    "split_list",
+    "is_master",
+    "is_distributed",
 ]
```
dataset_utils.py (new file, @@ -0,0 +1,114 @@):

```python
"""Utility functions for dataset processing."""

import io
from typing import List, Tuple

import torch
import torch.distributed as dist
from rich.console import Console
from rich.table import Table
from rich.text import Text
from tqdm import tqdm


def show_sample(
    input_ids,
    labels,
    tokenizer,
    title="Input and Labels",
    left_column="Input IDs",
    right_column="Labels"
):
    """Display a sample with input_ids and labels in a formatted table."""
    input_ids = input_ids.tolist()
    labels = labels.tolist()

    valid_labels_list = [token_id for token_id in labels if token_id != -100]
    decoded_input = tokenizer.decode(input_ids)
    decoded_labels = tokenizer.decode(valid_labels_list)

    table = Table(show_header=True, show_lines=True, title=title)
    table.add_column(left_column, overflow="fold")
    table.add_column(right_column, overflow="fold")

    wrapped_input = Text(decoded_input, no_wrap=False, overflow="fold")
    wrapped_labels = Text(decoded_labels, no_wrap=False, overflow="fold")

    table.add_row(str(input_ids), str(labels))
    table.add_row(wrapped_input, wrapped_labels)

    with io.StringIO() as buf:
        console = Console(file=buf, force_terminal=False)
        console.print(table)
        output = buf.getvalue()

    tqdm.write(output.rstrip())


def get_padding_value(tokenizer):
    """Get the padding token id from tokenizer.

    If pad_token_id is not set, use eos_token_id as fallback.
    """
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id

    eos = tokenizer.eos_token_id
    return eos[0] if isinstance(eos, list) else eos


def calculate_matched_group(sequences: List[Tuple[int, int]], packing_length: int, is_finished: bool = True):
    """Bin-packing via First Fit Decreasing (https://arxiv.org/pdf/2404.10830).

    Args:
        sequences: List of (index, length) tuples.
        packing_length: Maximum length for each pack.
        is_finished: Whether this is the last batch.

    Returns:
        Tuple of (packed_sequences, remaining_sequences).
        packed_sequences is a list of lists, each containing (index, length) tuples.
    """
    if len(sequences) == 0:
        return [], []
    import binpacking

    # sequences is a list of (index, length) tuples;
    # weight_pos=1 tells binpacking that the length sits at position 1 of each tuple.
    # Distribute the items across bins of fixed capacity so that no bin's
    # total length exceeds packing_length.
    sequences = binpacking.to_constant_volume(sequences, packing_length, weight_pos=1)

    # sequences is now a list of lists, each holding (index, length) tuples.
    # If this is not the last batch, hold back the final (possibly underfilled)
    # group for the next batch.
    if sequences and not is_finished:
        sequences, ret_sequences = sequences[:-1], sequences[-1]
    else:
        ret_sequences = []
    return sequences, ret_sequences


def split_list(lst: list, n: int) -> List[list]:
    """Split a list into n sublists as evenly as possible.

    Args:
        lst: The list to split.
        n: Number of parts to split into.

    Returns:
        List of n sublists.
    """
    # Split the list into n sublists, one per worker process.
    k, m = divmod(len(lst), n)
    return [lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]


def is_master() -> bool:
    """Check if current process is the master process in distributed training."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


def is_distributed() -> bool:
    """Check if running in distributed training mode."""
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
```
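`calculate_matched_group` delegates the actual packing to the `binpacking` package. The underlying First Fit Decreasing idea can be sketched without that dependency (a minimal illustration, not the PR's code; `ffd_pack` and the sample lengths below are made up for the example):

```python
from typing import List, Tuple


def ffd_pack(sequences: List[Tuple[int, int]], capacity: int) -> List[List[Tuple[int, int]]]:
    """Pack (index, length) items into bins of at most `capacity` total length
    via First Fit Decreasing: sort by length descending, then drop each item
    into the first bin that still has room, opening a new bin if none does."""
    bins: List[List[Tuple[int, int]]] = []
    totals: List[int] = []
    for item in sorted(sequences, key=lambda x: x[1], reverse=True):
        _, length = item
        for i, total in enumerate(totals):
            if total + length <= capacity:
                bins[i].append(item)
                totals[i] += length
                break
        else:
            # No existing bin has room; open a new one.
            bins.append([item])
            totals.append(length)
    return bins


# Four sequences with token lengths 700, 300, 500, 200, packed to length 1000.
packs = ffd_pack([(0, 700), (1, 300), (2, 500), (3, 200)], capacity=1000)
print(packs)  # [[(0, 700), (1, 300)], [(2, 500), (3, 200)]]
```

FFD is a greedy heuristic, so it does not always find the optimal packing, but it guarantees no bin exceeds the capacity, which is what matters for keeping packed samples within the model's context length.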
One reviewer notes:

When `type=bool` is used, it doesn't work as expected, because passing any non-empty string will be interpreted as `True`.
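This is a well-known `argparse` pitfall: `bool("False")` is `True`, since any non-empty string is truthy. A small sketch of the problem and one common fix (the flag names `--buggy` and `--fixed` are made up for illustration):

```python
import argparse


def str2bool(v: str) -> bool:
    # Explicitly map common spellings; reject anything else.
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {v!r}")


parser = argparse.ArgumentParser()
# Buggy: bool("False") is True, because any non-empty string is truthy.
parser.add_argument("--buggy", type=bool, default=False)
# Fixed: parse the string explicitly.
parser.add_argument("--fixed", type=str2bool, default=False)

args = parser.parse_args(["--buggy", "False", "--fixed", "False"])
print(args.buggy, args.fixed)  # True False
```

For flags that are simply on or off, `action="store_true"` avoids the problem entirely by taking no value at all.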