Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584674881
  • Loading branch information
davidsoergel authored and SeqIO committed Dec 1, 2023
1 parent 515d917 commit 8d95da8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
4 changes: 4 additions & 0 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
Feature = utils.Feature




@dataclasses.dataclass(frozen=True)
class ContinuousFeature(Feature):
"""A container for multi-modal output features of data providers."""
Expand Down Expand Up @@ -1458,6 +1460,8 @@ def cache_dir(self) -> Optional[str]:
os.path.join(d, utils.get_task_dir_from_name(self.name))
for d in utils.get_global_cache_dirs()
]


for cache_dir in potential_cache_dirs:
try:
if tf.io.gfile.exists(os.path.join(cache_dir, "COMPLETED")):
Expand Down
13 changes: 13 additions & 0 deletions seqio/dataset_providers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,18 @@ def test_requires_caching(self):
):
task.get_dataset({"inputs": 512, "targets": 512}, use_cached=False)

def test_requires_caching_not_cached(self):
utils.set_global_cache_dirs([self.test_data_dir])

task = dataset_providers.Task(
"requires_cache",
output_features=self.DEFAULT_OUTPUT_FEATURES,
source=self.function_source,
preprocessors=[
dataset_providers.CacheDatasetPlaceholder(required=True),
preprocessors.tokenize,
],
)
# We haven't actually cached the task, so it still fails but with a
# different error.
with self.assertRaisesWithLiteralMatch(
Expand Down Expand Up @@ -488,6 +500,7 @@ def test_set_global_cache_dirs(self):
utils.set_global_cache_dirs([self.test_data_dir])
self.assertTrue(self.cached_task.cache_dir)


def test_get_dataset_cached(self):
self.verify_task_matches_fake_datasets(
"cached_task", use_cached=True, token_preprocessed=False
Expand Down
2 changes: 2 additions & 0 deletions seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def add_global_cache_dirs(global_cache_dirs):
_GLOBAL_CACHE_DIRECTORIES += global_cache_dirs




def _validate_tfds_name(name: str) -> None:
"""Validates TFDS dataset name."""
if (
Expand Down

0 comments on commit 8d95da8

Please sign in to comment.