Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Zimtohrli python version easier to jit. #128

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions python/audio_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Generic classes containing signals."""

import dataclasses
import functools
from typing import Union
import jax
import jax.numpy as jnp
Expand All @@ -31,15 +32,17 @@
]


@jax.tree_util.register_pytree_node_class
@functools.partial(jax.tree_util.register_dataclass,
data_fields=['samples'],
meta_fields=['sample_rate'])
@dataclasses.dataclass(frozen=True)
class Signal:
"""Class defining a digital signal at a given sample rate.

Attributes:
sample_rate: The number of samples per second in the signal.
samples: A (num_samples,)-shaped array with the signal samples, in the range
-1 to 1.
samples: A (num_samples,)-shaped array with the signal samples, in the
range -1 to 1.
"""

sample_rate: Numerical
Expand All @@ -53,15 +56,17 @@ def tree_unflatten(cls, _, children):
return cls(*children)


@jax.tree_util.register_pytree_node_class
@functools.partial(jax.tree_util.register_dataclass,
data_fields=['samples'],
meta_fields=['sample_rate', 'freqs'])
@dataclasses.dataclass(frozen=True)
class Channels:
"""Class defining a set of digital signals being related channels.

Attributes:
sample_rate: The number of samples per second in the signal.
samples: A (num_channels, num_samples)-shaped array with the samples of the
channel signals, in the range -1 to 1.
samples: A (num_channels, num_samples)-shaped array with the samples of
the channel signals, in the range -1 to 1.
freqs: A (num_channels, 2)-shaped array with the low and high pass
frequencies of the channels.
"""
Expand Down Expand Up @@ -110,13 +115,13 @@ def to_db(
"""Returns the channels in dB relative full_scale_sine_db.

Make sure to only call this on Channels that are the result of calling
Channels.energy, since otherwise the dB conversion will get negative numbers
which will cause nans.
Channels.energy, since otherwise the dB conversion will get negative
numbers which will cause nans.

Args:
full_scale_sine_db: The reference dB SPL of a full scale sine.
db_epsilon: The epsilon to add to the energy before converting to dB to
avoid log of zero.
db_epsilon: The epsilon to add to the energy before converting to dB
to avoid log of zero.

Returns:
The energy in the channels in dB, downsample to the out_sample_rate.
Expand Down
8 changes: 4 additions & 4 deletions python/cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import audio_signal


@jax.tree_util.register_pytree_node_class
@jax.tree_util.register_static
@dataclasses.dataclass
class Cam:
"""Handles converting between Hz and Cam.
Expand Down Expand Up @@ -80,15 +80,15 @@ def __post_init__(self):
- start_cam
)
stop_cam = self.cam_from_hz(self.maximum_channel_upper_bound)
self._hz_freqs = self.hz_from_cam(jnp.arange(start_cam, stop_cam, cam_step))
self._hz_freqs = self.hz_from_cam(np.arange(start_cam, stop_cam, cam_step))

def hz_from_cam(self, cam: audio_signal.Numerical) -> audio_signal.Numerical:
"""Returns the Hz frequency for the provided Cam frequency."""
return (10 ** (cam / self.erbs_scale_1) - self.erbs_offset) / self.erbs_scale_2

def cam_from_hz(self, hz: audio_signal.Numerical) -> audio_signal.Numerical:
"""Returns the Cam frequency for the provided Hz frequency."""
return self.erbs_scale_1 * jnp.log10(self.erbs_offset + self.erbs_scale_2 * hz)
return self.erbs_scale_1 * np.log10(self.erbs_offset + self.erbs_scale_2 * hz)

def channel_filter(self, sig: audio_signal.Signal) -> audio_signal.Channels:
"""Returns the signal filtered through a filter bank."""
Expand All @@ -109,7 +109,7 @@ def channel_filter(self, sig: audio_signal.Signal) -> audio_signal.Channels:
)
)

freqs_ary = jnp.asarray(freqs)
freqs_ary = np.asarray(freqs)

return audio_signal.Channels(
sample_rate=sig.sample_rate,
Expand Down
3 changes: 2 additions & 1 deletion python/cam_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def test_hz_from_cam(self, cam, hz):
dict(hz=10000, cam=35.316578),
)
def test_cam_from_hz(self, hz, cam):
self.assertAlmostEqual(self.cam.cam_from_hz(hz), cam)
self.assertAlmostEqual(self.cam.cam_from_hz(hz), cam,
delta=1e-5)

def test_channel_filter(self):
fs = 48000
Expand Down
7 changes: 4 additions & 3 deletions python/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
"""Handles masking of sounds."""

import dataclasses
import functools
from typing import Optional
import jax
import jax.numpy as jnp
import numpy as np
import cam
import audio_signal


@jax.tree_util.register_pytree_node_class
@jax.tree_util.register_static
@dataclasses.dataclass
class Masking:
"""Handles masking of sounds.
Expand Down Expand Up @@ -131,7 +133,6 @@ def __post_init__(self):
jax.jit(jax.vmap(full_masking_multi_maskers_multi_probes, (None, 1), 1)),
)

@jax.jit
def non_masked_energy(
self, energy_channels_db: audio_signal.Channels
) -> audio_signal.Channels:
Expand All @@ -144,7 +145,7 @@ def non_masked_energy(
energy_channels after having removed masked components.
"""
cams = self.cam_model.cam_from_hz(energy_channels_db.freqs[:, 0])
cam_delta = cams[jnp.newaxis, ...] - cams[..., jnp.newaxis]
cam_delta = cams[np.newaxis, ...] - cams[..., np.newaxis]
max_full_masking_db = jnp.max(
self.full_masking_of_channels(cam_delta, energy_channels_db.samples), axis=0
).T
Expand Down
5 changes: 3 additions & 2 deletions python/zimtohrli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import dataclasses
import jax.numpy as jnp
import numpy as np
import cam
import loudness
import masking
Expand Down Expand Up @@ -45,8 +46,8 @@ class Zimtohrli:
def spectrogram(
self,
signal: audio_signal.Signal,
full_scale_sine_db: jnp.ndarray = jnp.asarray(90),
db_epsilon: jnp.ndarray = jnp.asarray(1e-9),
full_scale_sine_db: jnp.ndarray = np.asarray(90),
db_epsilon: jnp.ndarray = np.asarray(1e-9),
) -> audio_signal.Channels:
"""Returns a perceptual spectrogram of the signal.

Expand Down
22 changes: 18 additions & 4 deletions python/zimtohrli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Tests for google3.third_party.zimtohrli.python.zimtohrli."""

import unittest
import jax
import numpy as np
import audio_signal
import zimtohrli
Expand All @@ -29,12 +30,25 @@ def test_zimtohrli_spectrogram_and_distance(self):
signal_b[0:1] = 0.9
sound_a = audio_signal.Signal(sample_rate=sample_rate, samples=signal_a)
sound_b = audio_signal.Signal(sample_rate=sample_rate, samples=signal_b)
z = zimtohrli.Zimtohrli()
spectrogram_a = z.spectrogram(sound_a)
spectrogram_b = z.spectrogram(sound_b)
distance = z.distance(spectrogram_a, spectrogram_b)

def compute_distance(s_a, s_b):
z = zimtohrli.Zimtohrli()
spectrogram_a = z.spectrogram(s_a)
spectrogram_b = z.spectrogram(s_b)
distance = z.distance(spectrogram_a, spectrogram_b)
return distance

# Run it without jit
distance = compute_distance(sound_a, sound_b)
self.assertGreater(distance, 0)

# Run it under jit
jit_compute_distance = jax.jit(compute_distance)
jit_distance = jit_compute_distance(sound_a, sound_b)
self.assertGreater(jit_distance, 0)

self.assertAlmostEqual(distance, jit_distance, delta=1e-5)


if __name__ == "__main__":
unittest.main()
Loading