Skip to content

Commit

Permalink
Allow storing embeddings in float16, which speeds up IO considerably.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596086346
  • Loading branch information
sdenton4 authored and copybara-github committed Jan 5, 2024
1 parent c954cc0 commit 043ff73
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 18 deletions.
22 changes: 16 additions & 6 deletions chirp/inference/embed_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ def create_source_infos(


def get_existing_source_ids(
output_dir: epath.Path, file_pattern: str
output_dir: epath.Path, file_pattern: str, tensor_dtype: str = 'float32'
) -> Set[SourceId]:
"""Return existing SourceIds from the matching output dir and pattern."""
existing_source_ids = set([])
if not output_dir.exists:
return existing_source_ids
filenames = [fn for fn in output_dir.glob(file_pattern)]
dataset = tf.data.TFRecordDataset(filenames)
parser = tf_examples.get_example_parser()
parser = tf_examples.get_example_parser(tensor_dtype=tensor_dtype)
dataset = dataset.map(parser)
for e in dataset.as_numpy_iterator():
existing_source_ids.add(
Expand Down Expand Up @@ -167,6 +167,7 @@ def __init__(
embedding_model: interface.EmbeddingModel | None = None,
target_sample_rate: int = -2,
logits_head_config: config_dict.ConfigDict | None = None,
tensor_dtype: str = 'float32',
):
"""Initialize the embedding DoFn.
Expand All @@ -191,6 +192,8 @@ def __init__(
resample to a fixed rate.
logits_head_config: Optional configuration for a secondary
interface.LogitsOutputHead classifying the model embeddings.
tensor_dtype: Dtype to use for storing tensors (embeddings, logits, or
audio). Defaults to float32, but float16 approximately halves file size.
"""
self.model_key = model_key
self.model_config = model_config
Expand All @@ -205,6 +208,7 @@ def __init__(
self.target_sample_rate = target_sample_rate
self.logits_head_config = logits_head_config
self.logits_head = None
self.tensor_dtype = tensor_dtype

def setup(self):
if self.embedding_model is None:
Expand Down Expand Up @@ -273,6 +277,7 @@ def audio_to_example(
write_separated_audio=self.write_separated_audio,
write_embeddings=self.write_embeddings,
write_logits=write_logits,
tensor_dtype=self.tensor_dtype,
)
return example

Expand Down Expand Up @@ -343,12 +348,17 @@ def process(self, source_info: SourceInfo, crop_s: float = -1.0):
return [example]


def get_config(config_key: str):
def get_config(config_key: str, shard_idx: str = '') -> config_dict.ConfigDict:
"""Get a config given its keyed name."""
module_key = '..{}'.format(config_key)
config = importlib.import_module(
module_key, INFERENCE_CONFIGS_PKG
).get_config()
if shard_idx:
config = importlib.import_module(
module_key, INFERENCE_CONFIGS_PKG
).get_config(shard_idx)
else:
config = importlib.import_module(
module_key, INFERENCE_CONFIGS_PKG
).get_config()

logging.info('Loaded config %s', config_key)
logging.info('Config output location : %s', config.output_dir)
Expand Down
27 changes: 18 additions & 9 deletions chirp/inference/tf_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def get_feature_description(logit_names: Sequence[str] | None = None):
return feature_description


def get_example_parser(logit_names: Sequence[str] | None = None):
def get_example_parser(
logit_names: Sequence[str] | None = None, tensor_dtype: str = 'float32'
):
"""Create a parser for decoding inference library TFExamples."""
features = get_feature_description(logit_names=logit_names)

Expand All @@ -93,9 +95,9 @@ def _parser(ex):
# both conditional branches. So we use an empty tensor when no
# data is present to parse.
if ex[key] != tf.constant(b'', dtype=tf.string):
ex[key] = tf.io.parse_tensor(ex[key], tf.float32)
ex[key] = tf.io.parse_tensor(ex[key], out_type=tensor_dtype)
else:
ex[key] = tf.zeros_like([], dtype=tf.float32)
ex[key] = tf.zeros_like([], dtype=tensor_dtype)
return ex

return _parser
Expand All @@ -106,6 +108,7 @@ def create_embeddings_dataset(
file_glob: str = '*',
prefetch: int = 128,
logit_names: Sequence[str] | None = None,
tensor_dtype: str = 'float32',
):
"""Create a TF Dataset of the embeddings."""
embeddings_dir = epath.Path(embeddings_dir)
Expand All @@ -114,13 +117,16 @@ def create_embeddings_dataset(
embeddings_files, num_parallel_reads=tf.data.AUTOTUNE
)

parser = get_example_parser(logit_names=logit_names)
parser = get_example_parser(
logit_names=logit_names, tensor_dtype=tensor_dtype
)
ds = ds.map(parser, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.prefetch(prefetch)
return ds


def serialize_tensor(tensor: np.ndarray) -> np.ndarray:
def serialize_tensor(tensor: np.ndarray, tensor_dtype: str) -> np.ndarray:
tensor = tf.cast(tensor, tensor_dtype)
serialized = tf.io.serialize_tensor(tensor)
return serialized.numpy()

Expand All @@ -134,6 +140,7 @@ def model_outputs_to_tf_example(
write_logits: bool | Sequence[str],
write_separated_audio: bool,
write_raw_audio: bool,
tensor_dtype: str = 'float32',
) -> tf.train.Example:
"""Create a TFExample from InferenceOutputs."""
feature = {
Expand All @@ -142,7 +149,7 @@ def model_outputs_to_tf_example(
}
if write_embeddings and model_outputs.embeddings is not None:
feature[EMBEDDING] = bytes_feature(
serialize_tensor(model_outputs.embeddings)
serialize_tensor(model_outputs.embeddings, tensor_dtype)
)
feature[EMBEDDING_SHAPE] = (int_feature(model_outputs.embeddings.shape),)

Expand All @@ -154,19 +161,21 @@ def model_outputs_to_tf_example(
logit_keys = tuple(k for k in logit_keys if k in write_logits)
for logits_key in logit_keys:
logits = model_outputs.logits[logits_key]
feature[logits_key] = bytes_feature(serialize_tensor(logits))
feature[logits_key] = bytes_feature(
serialize_tensor(logits, tensor_dtype)
)
feature[logits_key + '_shape'] = int_feature(logits.shape)

if write_separated_audio and model_outputs.separated_audio is not None:
feature[SEPARATED_AUDIO] = bytes_feature(
serialize_tensor(model_outputs.separated_audio)
serialize_tensor(model_outputs.separated_audio, tensor_dtype)
)
feature[SEPARATED_AUDIO_SHAPE] = int_feature(
model_outputs.separated_audio.shape
)
if write_raw_audio:
feature[RAW_AUDIO] = bytes_feature(
serialize_tensor(tf.constant(audio, dtype=tf.float32))
serialize_tensor(tf.constant(audio, dtype=tf.float32), tensor_dtype)
)
feature[RAW_AUDIO_SHAPE] = int_feature(audio.shape)
ex = tf.train.Example(features=tf.train.Features(feature=feature))
Expand Down
9 changes: 8 additions & 1 deletion chirp/projects/bootstrap/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def create_embeddings_dataset(self):
if self.embeddings_dataset:
return self.embeddings_dataset
ds = tf_examples.create_embeddings_dataset(
self.config.embeddings_path, 'embeddings-*'
self.config.embeddings_path,
'embeddings-*',
tensor_dtype=self.config.tensor_dtype,
)
self.embeddings_dataset = ds
return ds
Expand Down Expand Up @@ -84,6 +86,9 @@ class BootstrapConfig:
# Annotations info.
annotated_path: str

# Tensor dtype in embeddings.
tensor_dtype: str

# The following are populated automatically from the embedding config.
embedding_hop_size_s: float | None = None
file_id_depth: int | None = None
Expand All @@ -98,6 +103,7 @@ def load_from_embedding_config(
"""Instantiate from a configuration written alongside embeddings."""
embedding_config = embed_lib.load_embedding_config(embeddings_path)
embed_fn_config = embedding_config.embed_fn_config
tensor_dtype = embed_fn_config.get('tensor_dtype', 'float32')

# Extract the embedding model config from the embedding_config.
if embed_fn_config.model_key == 'separate_embed_model':
Expand All @@ -115,4 +121,5 @@ def load_from_embedding_config(
embedding_hop_size_s=model_config.hop_size_s,
file_id_depth=embed_fn_config.file_id_depth,
audio_globs=embedding_config.source_file_patterns,
tensor_dtype=tensor_dtype,
)
4 changes: 3 additions & 1 deletion chirp/projects/multicluster/data_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def read_embedded_dataset(
embeddings_path: str,
time_pooling: str,
exclude_classes: Sequence[str] = (),
tensor_dtype: str = 'float32',
):
"""Read pre-saved embeddings to memory from storage.
Expand All @@ -420,6 +421,7 @@ def read_embedded_dataset(
embeddings_path: Location of the existing embeddings as TFRecordDataset.
time_pooling: Method of time pooling.
exclude_classes: List of classes to exclude.
tensor_dtype: Tensor dtype used in the embeddings tfrecords.
Returns:
Ordered labels and a Dict containing the entire embedded dataset.
Expand All @@ -428,7 +430,7 @@ def read_embedded_dataset(
output_dir = epath.Path(embeddings_path)
fns = [fn for fn in output_dir.glob('embeddings-*')]
ds = tf.data.TFRecordDataset(fns)
parser = tf_examples.get_example_parser()
parser = tf_examples.get_example_parser(tensor_dtype=tensor_dtype)
ds = ds.map(parser)

# Loading the labels assuming a folder-of-folder structure
Expand Down
6 changes: 5 additions & 1 deletion chirp/tests/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class InferenceTest(parameterized.TestCase):
write_logits=(True, False),
write_separated_audio=(True, False),
write_raw_audio=(True, False),
tensor_dtype=('float32', 'float16'),
)
def test_embed_fn(
self,
Expand All @@ -67,6 +68,7 @@ def test_embed_fn(
write_logits,
write_raw_audio,
write_separated_audio,
tensor_dtype,
):
model_kwargs = {
'sample_rate': 16000,
Expand All @@ -83,6 +85,7 @@ def test_embed_fn(
model_key='placeholder_model',
model_config=model_kwargs,
file_id_depth=0,
tensor_dtype=tensor_dtype,
)
embed_fn.setup()
self.assertIsNotNone(embed_fn.embedding_model)
Expand All @@ -98,7 +101,8 @@ def test_embed_fn(
serialized = example.SerializeToString()

parser = tf_examples.get_example_parser(
logit_names=['label', 'other_label']
logit_names=['label', 'other_label'],
tensor_dtype=tensor_dtype,
)
got_example = parser(serialized)
self.assertIsNotNone(got_example)
Expand Down

0 comments on commit 043ff73

Please sign in to comment.