Skip to content

Commit

Permalink
Fix saved model embedding shape (adds missing channel dimension), and…
Browse files Browse the repository at this point in the history
… add some tests.

PiperOrigin-RevId: 592300844
  • Loading branch information
sdenton4 authored and copybara-github committed Dec 19, 2023
1 parent 89d4a7b commit ac48ce0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 7 deletions.
4 changes: 3 additions & 1 deletion chirp/inference/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ def frame_audio(
def normalize_audio(
self,
framed_audio: np.ndarray,
target_peak: float,
target_peak: float | None,
) -> np.ndarray:
"""Normalizes audio with shape [..., T] to match the target_peak value."""
if target_peak is None:
return framed_audio
framed_audio = framed_audio.copy()
framed_audio -= np.mean(framed_audio, axis=-1, keepdims=True)
peak_norm = np.max(np.abs(framed_audio), axis=-1, keepdims=True)
Expand Down
21 changes: 15 additions & 6 deletions chirp/inference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,15 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
framed_audio = self.frame_audio(
audio_array, self.window_size_s, self.hop_size_s
)
if self.target_peak is not None:
framed_audio = self.normalize_audio(framed_audio, self.target_peak)
framed_audio = self.normalize_audio(framed_audio, self.target_peak)

all_logits, all_embeddings = self.model.infer_tf(framed_audio[:1])
for window in framed_audio[1:]:
logits, embeddings = self.model.infer_tf(window[np.newaxis, :])
all_logits = np.concatenate([all_logits, logits], axis=0)
all_embeddings = np.concatenate([all_embeddings, embeddings], axis=0)

# Add channel dimension.
all_embeddings = all_embeddings[:, np.newaxis, :]

return interface.InferenceOutputs(
Expand All @@ -351,17 +353,24 @@ def batch_embed(
framed_audio = self.frame_audio(
audio_batch, self.window_size_s, self.hop_size_s
)
if self.target_peak is not None:
framed_audio = self.normalize_audio(framed_audio, self.target_peak)
framed_audio = self.normalize_audio(framed_audio, self.target_peak)

rebatched_audio = framed_audio.reshape([-1, framed_audio.shape[-1]])
logits, embeddings = self.model.infer_tf(rebatched_audio)
logits = np.reshape(logits, framed_audio.shape[:2] + (logits.shape[-1],))
# Unbatch and add channel dimension.
embeddings = np.reshape(
embeddings, framed_audio.shape[:2] + (embeddings.shape[-1],)
embeddings,
framed_audio.shape[:2]
+ (
1,
embeddings.shape[-1],
),
)

return interface.InferenceOutputs(embeddings, {'label': logits}, None)
return interface.InferenceOutputs(
embeddings, {'label': logits}, None, batched=True
)


@dataclasses.dataclass
Expand Down
53 changes: 53 additions & 0 deletions chirp/tests/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,59 @@ def test_beam_pipeline(self):

print(metrics)

@parameterized.product(
    batchable=(True, False),
)
def test_taxonomy_model_tf(self, batchable):
  """Checks TaxonomyModelTF output shapes using a stubbed-out model.

  The wrapper is built around a fake model whose `infer_tf` returns all-zero
  arrays with 3 logits and a 256-dim embedding per input row. With 5 second
  windows and a 5 second hop at 32 kHz, the assertions below expect one
  frame per 5 seconds of audio, and expect the embeddings to carry an extra
  singleton axis just before the embedding depth.

  Args:
    batchable: Forwarded to the TaxonomyModelTF constructor.
  """

  class FakeModelFn:
    """Stand-in model returning zeros of the configured output depths."""

    output_depths = (3, 256)

    def infer_tf(self, audio_array):
      # One all-zero row per input row, for each configured output depth.
      num_rows = audio_array.shape[0]
      return [
          np.zeros([num_rows, depth], dtype=np.float32)
          for depth in self.output_depths
      ]

  wrapped_model = models.TaxonomyModelTF(
      sample_rate=32000,
      model_path='/dev/null',
      window_size_s=5.0,
      hop_size_s=5.0,
      model=FakeModelFn(),
      class_list=namespace.ClassList('fake', ['alpha', 'beta', 'delta']),
      batchable=batchable,
  )

  def silence(*shape):
    # All-zero float32 audio of the requested shape.
    return np.zeros(shape, dtype=np.float32)

  # Unbatched: single-frame audio first, then multi-frame audio.
  for seconds, frames in ((5, 1), (20, 4)):
    outputs = wrapped_model.embed(silence(seconds * 32000))
    self.assertFalse(outputs.batched)
    self.assertSequenceEqual(outputs.embeddings.shape, [frames, 1, 256])
    self.assertSequenceEqual(outputs.logits['label'].shape, [frames, 3])

  # Batched: a batch of single-frame clips, then a batch of multi-frame clips.
  for batch_size, seconds, frames in ((10, 5, 1), (2, 20, 4)):
    outputs = wrapped_model.batch_embed(silence(batch_size, seconds * 32000))
    self.assertTrue(outputs.batched)
    self.assertSequenceEqual(
        outputs.embeddings.shape, [batch_size, frames, 1, 256]
    )
    self.assertSequenceEqual(
        outputs.logits['label'].shape, [batch_size, frames, 3]
    )


if __name__ == '__main__':
absltest.main()

0 comments on commit ac48ce0

Please sign in to comment.