Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 769abb3

Browse files
authored
Updates to enable ultrachat200k
Ultrachat200k has 2 splits for training, one for sft and another for dpo. As a result it doesn't have a "train" split per se. This PR allows for a train_sft alternative.
1 parent f784980 commit 769abb3

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/sparseml/transformers/finetune/data/data_helpers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,12 @@ def make_dataset_splits(
128128
train_split = eval_split = predict_split = calib_split = None
129129

130130
if do_train:
131-
if "train" not in tokenized_datasets:
131+
if "train" in tokenized_datasets:
132+
train_split = tokenized_datasets["train"]
133+
elif "train_sft" in tokenized_datasets:
134+
train_split = tokenized_datasets["train_sft"]
135+
else:
132136
raise ValueError("--do_train requires a train dataset")
133-
train_split = tokenized_datasets["train"]
134137
if do_eval:
135138
if "validation" not in tokenized_datasets:
136139
raise ValueError("--do_eval requires a validation dataset")
@@ -142,7 +145,11 @@ def make_dataset_splits(
142145
if do_oneshot:
143146
calib_split = tokenized_datasets.get("calibration")
144147
if calib_split is None:
145-
if "train" not in tokenized_datasets:
148+
if "train" in tokenized_datasets:
149+
train_split = tokenized_datasets["train"]
150+
elif "train_sft" in tokenized_datasets:
151+
train_split = tokenized_datasets["train_sft"]
152+
else:
146153
raise ValueError("--do_oneshot requires a calibration dataset")
147154
calib_split = tokenized_datasets["train"]
148155

0 commit comments

Comments
 (0)