diff --git a/src/sparseml/transformers/finetune/data/data_helpers.py b/src/sparseml/transformers/finetune/data/data_helpers.py
index 243f4085023..31dcb53a920 100644
--- a/src/sparseml/transformers/finetune/data/data_helpers.py
+++ b/src/sparseml/transformers/finetune/data/data_helpers.py
@@ -128,9 +128,12 @@ def make_dataset_splits(
     train_split = eval_split = predict_split = calib_split = None
     if do_train:
-        if "train" not in tokenized_datasets:
+        if "train" in tokenized_datasets:
+            train_split = tokenized_datasets["train"]
+        elif "train_sft" in tokenized_datasets:
+            train_split = tokenized_datasets["train_sft"]
+        else:
             raise ValueError("--do_train requires a train dataset")
-        train_split = tokenized_datasets["train"]
     if do_eval:
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
@@ -142,7 +145,10 @@
     if do_oneshot:
         calib_split = tokenized_datasets.get("calibration")
         if calib_split is None:
-            if "train" not in tokenized_datasets:
+            if "train" in tokenized_datasets:
+                calib_split = tokenized_datasets["train"]
+            elif "train_sft" in tokenized_datasets:
+                calib_split = tokenized_datasets["train_sft"]
+            else:
                 raise ValueError("--do_oneshot requires a calibration dataset")
-            calib_split = tokenized_datasets["train"]