From b9de9096f37bbbf0dad01209f356cfea217930d6 Mon Sep 17 00:00:00 2001 From: Gabe Weisz Date: Tue, 2 Dec 2025 07:20:08 -0600 Subject: [PATCH 1/2] Plumb config.max_segments_per_seq to grain PackAndBatchOperation --- src/MaxText/input_pipeline/_hf_data_processing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/MaxText/input_pipeline/_hf_data_processing.py b/src/MaxText/input_pipeline/_hf_data_processing.py index e056cd972..1dc671e44 100644 --- a/src/MaxText/input_pipeline/_hf_data_processing.py +++ b/src/MaxText/input_pipeline/_hf_data_processing.py @@ -192,6 +192,7 @@ def preprocessing_pipeline( use_sft=None, sft_train_on_completion_only=True, grain_worker_count=1, # only support 0 or 1 + max_segments_per_seq=1, ): """pipeline for preprocessing HF dataset""" @@ -301,6 +302,7 @@ def lists2array(x): grain.experimental.PackAndBatchOperation( batch_size=global_batch_size // jax.process_count(), length_struct=length_struct, + max_sequences_per_bin=max_segments_per_seq, ) ) operations.append(_input_pipeline_utils.ReformatPacking(data_column_names)) @@ -386,6 +388,7 @@ def make_hf_train_iterator( use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, chat_template_path=config.chat_template_path, + max_sequences_per_bin=config.max_segments_per_seq, ) return train_iter From 6f2a59e2643e334e4570934ee44eee10709bab44 Mon Sep 17 00:00:00 2001 From: Gabe Weisz Date: Thu, 4 Dec 2025 12:55:33 -0600 Subject: [PATCH 2/2] Add max_sequences_per_bin to make_hf_eval_iterator too --- src/MaxText/input_pipeline/_hf_data_processing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/MaxText/input_pipeline/_hf_data_processing.py b/src/MaxText/input_pipeline/_hf_data_processing.py index 1dc671e44..5936c88af 100644 --- a/src/MaxText/input_pipeline/_hf_data_processing.py +++ b/src/MaxText/input_pipeline/_hf_data_processing.py @@ -440,5 +440,6 @@ def make_hf_eval_iterator( use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, chat_template_path=config.chat_template_path, + max_sequences_per_bin=config.max_segments_per_seq, ) return eval_iter