Add PreprocessTask.preprocess_postcache option to perform preprocessing for tasks that don't support caching.

PiperOrigin-RevId: 582875831
SeqIO Team authored and SeqIO committed Nov 16, 2023
1 parent 44fecb3 commit 394097f
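
For orientation, a minimal sketch of how the new option might be wired into a Beam pipeline. The task name, sequence lengths, and output path are hypothetical, and the `WriteExampleTfRecord` sink is assumed from the existing `beam_utils` API, so treat this as an illustration rather than the canonical invocation:

import apache_beam as beam
import seqio
from seqio import beam_utils

def build_pipeline():
  with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        # Emit fully preprocessed examples even though the task is uncached:
        # preprocess_postcache=True forces the postcache preprocessors to run.
        | "Preprocess" >> beam_utils.PreprocessTask(
            task=seqio.get_mixture_or_task("my_task"),  # hypothetical task
            split="train",
            preprocess_postcache=True,
            sequence_length={"inputs": 512, "targets": 128},  # illustrative
        )
        # Assumed sink from beam_utils; the path is hypothetical.
        | "Write" >> beam_utils.WriteExampleTfRecord(
            "/tmp/my_task.train.tfrecord"
        )
    )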
Showing 2 changed files with 57 additions and 10 deletions.
42 changes: 32 additions & 10 deletions seqio/beam_utils.py
@@ -61,6 +61,8 @@ def __init__(
       split: str,
       *,
       preprocessors_seed: Optional[int] = None,
+      preprocess_postcache: bool = False,
+      sequence_length: Optional[Mapping[str, int]] = None,
       setup_fn: Callable[[], None] = lambda: None,
       modules_to_import: Sequence[str] = (),
       add_provenance: bool = False,
@@ -73,6 +75,11 @@ def __init__(
       split: string, the split to process.
       preprocessors_seed: (Optional) int, a seed for stateless random ops in
         task preprocessing.
+      preprocess_postcache: (Optional) If True, the `preprocess_postcache()`
+        function is invoked on the task, ensuring that preprocessing is
+        performed regardless of whether the task supports caching.
+      sequence_length: (Optional) A map of feature key to maximum int length
+        for that feature. Used only if `preprocess_postcache` is True.
       setup_fn: (Optional) callable, a function called before loading the task.
       modules_to_import: (Optional) list, modules to import.
       add_provenance: If True, provenance is added to each example.
@@ -85,6 +92,8 @@ def __init__(
     self._task_name = task.name
     self._split = split
     self._preprocessors_seed = preprocessors_seed
+    self._preprocess_postcache = preprocess_postcache
+    self._sequence_length = sequence_length
     self._setup_fn = setup_fn
     self._modules_to_import = modules_to_import
     self._add_provenance = add_provenance
@@ -135,16 +144,7 @@ def _emit_examples(self, shard: Tuple[int, str]):
         self._preprocessors_seed or 0
     )

-    ds = task.source.get_dataset(
-        split=self._split,
-        shard_info=seqio.ShardInfo(
-            index=shard_index, num_shards=len(self.shards)
-        ),
-        shuffle=False,
-        seed=shard_preprocessors_seed,
-    )
-    ds = task.preprocess_precache(ds, seed=shard_preprocessors_seed)
-    ds = ds.prefetch(tf.data.AUTOTUNE)
+    ds = self._get_dataset(task, shard_index, shard_preprocessors_seed)

     def _add_provenance(
         index_within_shard: int, ex: Dict[str, Any]) -> Dict[str, Any]:
@@ -167,6 +167,28 @@ def _add_provenance(
       logging.info("Example [%d] = %s", i, ex)
       yield ex

+  def _get_dataset(
+      self, task: seqio.Task, shard_index: int, shard_preprocessors_seed: int
+  ) -> tf.data.Dataset:
+    """Gets and preprocesses the dataset for the provided task."""
+    ds = task.source.get_dataset(
+        split=self._split,
+        shard_info=seqio.ShardInfo(
+            index=shard_index, num_shards=len(self.shards)
+        ),
+        shuffle=False,
+        seed=shard_preprocessors_seed,
+    )
+    ds = task.preprocess_precache(ds, seed=shard_preprocessors_seed)
+    if self._preprocess_postcache:
+      ds = task.preprocess_postcache(
+          ds,
+          sequence_length=self._sequence_length,
+          seed=shard_preprocessors_seed,
+      )
+    ds = ds.prefetch(tf.data.AUTOTUNE)
+    return ds
+
   def expand(self, pipeline):
     return (
         pipeline
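
Conceptually, the refactored `_get_dataset` makes the per-shard flow explicit: precache preprocessing always runs, while postcache preprocessing (the stage that consumes `sequence_length`) runs only when the new flag is set. The same flow, sketched as direct seqio calls outside Beam, assuming a single shard and illustrative feature lengths:

import seqio
import tensorflow as tf

def preprocess_single_shard(task: seqio.Task, split: str, seed: int):
  # Read the raw examples for the only shard, deterministically.
  ds = task.source.get_dataset(
      split=split,
      shard_info=seqio.ShardInfo(index=0, num_shards=1),
      shuffle=False,
      seed=seed,
  )
  # Cache-compatible preprocessors always run.
  ds = task.preprocess_precache(ds, seed=seed)
  # What preprocess_postcache=True adds: sequence-length-dependent
  # preprocessing, so tasks that don't support caching still produce
  # fully preprocessed output.
  ds = task.preprocess_postcache(
      ds,
      sequence_length={"inputs": 512, "targets": 128},  # illustrative
      seed=seed,
  )
  return ds.prefetch(tf.data.AUTOTUNE)

Factoring this sequence into a private helper also keeps `_emit_examples` focused on shard bookkeeping and provenance.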
25 changes: 25 additions & 0 deletions seqio/beam_utils_test.py
@@ -109,6 +109,31 @@ def test_preprocess_task_with_setup_fn(self):
     self.assertLen(counters, 1)
     self.assertGreater(counters[0].committed, 0)

+  def test_preprocess_task_with_preprocess_postcache(self):
+    def preprocessor(dataset, sequence_length=None):
+      self.assertEqual(sequence_length, 123)
+      beam.metrics.Metrics.counter("test", "preprocessor_called").inc()
+      return dataset
+
+    self.add_task(
+        "test_task", source=self.tfds_source, preprocessors=[preprocessor]
+    )
+    with TestPipeline() as p:
+      pcoll = p | beam_utils.PreprocessTask(
+          task=seqio.get_mixture_or_task("test_task"),
+          split="train",
+          preprocess_postcache=True,
+          sequence_length=123,
+      )
+      result = p.run()
+      util.assert_that(pcoll, util.is_not_empty())
+
+    counters = result.metrics().query(
+        beam.metrics.MetricsFilter().with_name("preprocessor_called")
+    )["counters"]
+    self.assertLen(counters, 1)
+    self.assertGreater(counters[0].committed, 0)
+
   def test_write_example_tf_record(self):
     output_path = os.path.join(self.test_data_dir, "output.tfrecord")
     example = {
