diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index 04c719d8..ac9701d2 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -1665,6 +1665,8 @@ def get_dataset( # Shuffle before mixing since preprocessor can output multiple # (correlated) examples per input. ds = ds.shuffle(shuffle_buffer_size, seed=seed) + + return ds.prefetch(tf.data.experimental.AUTOTUNE) def _get_cached_source( diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index 27bf3b7f..6ef62981 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -945,6 +945,7 @@ def test_plaintext_to_pretokenized_rename(self): ), ) + def test_list_shards(self): def _get_formatted_shards_list(task_name, split): shards = dataset_providers.get_mixture_or_task(