Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a streaming spectrogram function #105

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions cpp/zimt/goohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,28 +88,20 @@ void FreeZimtohrli(Zimtohrli zimtohrli) {
delete static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
}

Analysis Analyze(Zimtohrli zimtohrli, float* data, int size) {
float Distance(Zimtohrli zimtohrli, float* data_a, int size_a, float* data_b,
int size_b) {
zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
hwy::AlignedNDArray<float, 1> signal({static_cast<size_t>(size)});
hwy::CopyBytes(data, signal.data(), size * sizeof(float));
hwy::AlignedNDArray<float, 2> channels(
{signal.shape()[0], z->cam_filterbank->filter.Size()});
zimtohrli::Analysis analysis = z->Analyze(signal[{}], channels);
return new zimtohrli::Analysis{
.energy_channels_db = std::move(analysis.energy_channels_db),
.partial_energy_channels_db =
std::move(analysis.partial_energy_channels_db),
.spectrogram = std::move(analysis.spectrogram)};
}

void FreeAnalysis(Analysis a) { delete static_cast<zimtohrli::Analysis*>(a); }

float AnalysisDistance(Zimtohrli zimtohrli, Analysis a, Analysis b) {
zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
zimtohrli::Analysis* analysis_a = static_cast<zimtohrli::Analysis*>(a);
zimtohrli::Analysis* analysis_b = static_cast<zimtohrli::Analysis*>(b);
return z->Distance(false, analysis_a->spectrogram, analysis_b->spectrogram)
.value;
hwy::AlignedNDArray<float, 1> signal_a({static_cast<size_t>(size_a)});
hwy::CopyBytes(data_a, signal_a.data(), size_a * sizeof(float));
hwy::AlignedNDArray<float, 1> signal_b({static_cast<size_t>(size_b)});
hwy::CopyBytes(data_b, signal_b.data(), size_b * sizeof(float));
const hwy::AlignedNDArray<float, 2> spectrogram_a =
z->StreamingSpectrogram(signal_a[{}]);
const hwy::AlignedNDArray<float, 2> spectrogram_b =
z->StreamingSpectrogram(signal_b[{}]);
const zimtohrli::Distance distance =
z->Distance(false, spectrogram_a, spectrogram_b);
return distance.value;
}

ZimtohrliParameters GetZimtohrliParameters(const Zimtohrli zimtohrli) {
Expand Down
140 changes: 23 additions & 117 deletions cpp/zimt/pyohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,6 @@

namespace {

struct AnalysisObject {
// clang-format off
PyObject_HEAD
zimtohrli::Analysis *analysis;
// clang-format on
};

void Analysis_dealloc(AnalysisObject* self) {
if (self->analysis) {
delete self->analysis;
self->analysis = nullptr;
}
Py_TYPE(self)->tp_free((PyObject*)self);
}

PyTypeObject AnalysisType = {
// clang-format off
.ob_base = PyVarObject_HEAD_INIT(nullptr, 0)
.tp_name = "pyohrli.Analysis",
// clang-format on
.tp_basicsize = sizeof(AnalysisObject),
.tp_itemsize = 0,
.tp_dealloc = (destructor)Analysis_dealloc,
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_doc = PyDoc_STR("Python wrapper around C++ zimtohrli::Analysis."),
.tp_new = PyType_GenericNew,
};

struct PyohrliObject {
// clang-format off
PyObject_HEAD
Expand Down Expand Up @@ -99,15 +71,20 @@ struct BufferDeleter {
void operator()(Py_buffer* buffer) const { PyBuffer_Release(buffer); }
};

// Plain C++ function to analyze a Python buffer object using Zimtohrli.
PyObject* BadArgument(const std::string& message) {
PyErr_SetString(PyExc_TypeError, message.c_str());
return nullptr;
}

// Plain C++ function to copy a Python buffer object to a hwy::AlignedNDArray.
//
// Calls to Analyze never need to be cleaned up (with e.g. delete or DECREF)
// Calls to CopyBuffer never need to be cleaned up (with e.g. delete or DECREF)
// afterwards.
//
// If the return value is std::nullopt that means a Python error is set and the
// current operation should be terminated ASAP.
std::optional<zimtohrli::Analysis> Analyze(
const zimtohrli::Zimtohrli& zimtohrli, PyObject* buffer_object) {
std::optional<hwy::AlignedNDArray<float, 1>> CopyBuffer(
PyObject* buffer_object) {
Py_buffer buffer_view;
if (PyObject_GetBuffer(buffer_object, &buffer_view, PyBUF_C_CONTIGUOUS)) {
PyErr_SetString(PyExc_TypeError, "object is not buffer");
Expand All @@ -124,96 +101,34 @@ std::optional<zimtohrli::Analysis> Analyze(
}
hwy::AlignedNDArray<float, 1> signal_array({buffer_view.len / sizeof(float)});
hwy::CopyBytes(buffer_view.buf, signal_array.data(), buffer_view.len);
hwy::AlignedNDArray<float, 2> channels(
{signal_array.size(), zimtohrli.cam_filterbank->filter.Size()});
return std::optional<zimtohrli::Analysis>{
zimtohrli.Analyze(signal_array[{}], channels)};
}

PyObject* BadArgument(const std::string& message) {
PyErr_SetString(PyExc_TypeError, message.c_str());
return nullptr;
}

PyObject* Pyohrli_analyze(PyohrliObject* self, PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 1) {
return BadArgument("not exactly 1 argument provided");
}
std::optional<zimtohrli::Analysis> analysis =
Analyze(*self->zimtohrli, args[0]);
if (!analysis.has_value()) {
return nullptr;
}
AnalysisObject* result = PyObject_New(AnalysisObject, &AnalysisType);
if (result == nullptr) {
return nullptr;
}
try {
result->analysis = new zimtohrli::Analysis{
.energy_channels_db = std::move(analysis->energy_channels_db),
.partial_energy_channels_db =
std::move(analysis->partial_energy_channels_db),
.spectrogram = std::move(analysis->spectrogram)};
return (PyObject*)result;
} catch (const std::bad_alloc&) {
// Technically, this object should be deleted with PyObject_Del, but
// XDECREF includes a null check which we want anyway.
Py_XDECREF((PyObject*)result);
return PyErr_NoMemory();
}
}

// Plain C++ function to compute distance between two zimtohrli::Analysis.
//
// Calls to Distance never need to be cleaned up (with e.g. delete or DECREF)
// afterwards.
PyObject* Distance(const zimtohrli::Zimtohrli& zimtohrli,
const zimtohrli::Analysis& analysis_a,
const zimtohrli::Analysis& analysis_b) {
const zimtohrli::Distance distance =
zimtohrli.Distance(false, analysis_a.spectrogram, analysis_b.spectrogram);
return PyFloat_FromDouble(distance.value);
}

PyObject* Pyohrli_analysis_distance(PyohrliObject* self, PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 2) {
return BadArgument("not exactly 2 arguments provided");
}
if (!Py_IS_TYPE(args[0], &AnalysisType)) {
return BadArgument("argument 0 is not an Analysis instance");
}
if (!Py_IS_TYPE(args[1], &AnalysisType)) {
return BadArgument("argument 1 is not an Analysis instance");
}
return Distance(*self->zimtohrli, *((AnalysisObject*)args[0])->analysis,
*((AnalysisObject*)args[1])->analysis);
return signal_array;
}

PyObject* Pyohrli_distance(PyohrliObject* self, PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 2) {
return BadArgument("not exactly 2 arguments provided");
}
const std::optional<zimtohrli::Analysis> analysis_a =
Analyze(*self->zimtohrli, args[0]);
if (!analysis_a.has_value()) {
const std::optional<hwy::AlignedNDArray<float, 1>> signal_a =
CopyBuffer(args[0]);
if (!signal_a.has_value()) {
return nullptr;
}
const std::optional<zimtohrli::Analysis> analysis_b =
Analyze(*self->zimtohrli, args[1]);
if (!analysis_b.has_value()) {
const std::optional<hwy::AlignedNDArray<float, 1>> signal_b =
CopyBuffer(args[1]);
if (!signal_b.has_value()) {
return nullptr;
}
return Distance(*self->zimtohrli, analysis_a.value(), analysis_b.value());
const hwy::AlignedNDArray<float, 2> spectrogram_a =
self->zimtohrli->StreamingSpectrogram((*signal_a)[{}]);
const hwy::AlignedNDArray<float, 2> spectrogram_b =
self->zimtohrli->StreamingSpectrogram((*signal_b)[{}]);
const zimtohrli::Distance distance =
self->zimtohrli->Distance(false, spectrogram_a, spectrogram_b);
return PyFloat_FromDouble(distance.value);
}

PyMethodDef Pyohrli_methods[] = {
{"analyze", (PyCFunction)Pyohrli_analyze, METH_FASTCALL,
"Returns an analysis of the provided signal."},
{"analysis_distance", (PyCFunction)Pyohrli_analysis_distance, METH_FASTCALL,
"Returns the distance between the two provided analyses."},
{"distance", (PyCFunction)Pyohrli_distance, METH_FASTCALL,
"Returns the distance between the two provided signals."},
{nullptr} /* Sentinel */
Expand Down Expand Up @@ -263,15 +178,6 @@ PyMODINIT_FUNC PyInit__pyohrli(void) {
PyObject* m = PyModule_Create(&PyohrliModule);
if (m == nullptr) return nullptr;

if (PyType_Ready(&AnalysisType) < 0) {
Py_DECREF(m);
return nullptr;
}
if (PyModule_AddObjectRef(m, "Analysis", (PyObject*)&AnalysisType) < 0) {
Py_DECREF(m);
return nullptr;
}

if (PyType_Ready(&PyohrliType) < 0) {
Py_DECREF(m);
return nullptr;
Expand Down
44 changes: 0 additions & 44 deletions cpp/zimt/pyohrli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ def mos_from_zimtohrli(zimtohrli_distance: float) -> float:
return _pyohrli.MOSFromZimtohrli(zimtohrli_distance)


class Analysis:
"""Wrapper around C++ zimtohrli::Analysis."""

_cc_analysis: _pyohrli.Analysis


class Pyohrli:
"""Wrapper around C++ zimtohrli::Zimtohrli."""

Expand All @@ -45,44 +39,6 @@ def __init__(self, sample_rate: float):
"""
self._cc_pyohrli = _pyohrli.Pyohrli(sample_rate)

def analyze(self, signal: npt.ArrayLike) -> Analysis:
"""Analyzes a signal.

Args:
signal: The signal to analyze. A (num_samples,)-shaped array of floats
between -1 and 1. The expected playout intensity in dB SPL of a 1kHz
sine wave between -1 and 1 is defined by setting 'full_scale_sine_db' of
this Pyohrli instance.

Returns:
An Analysis instance containing a psychoacoustic analysis of the signal.
"""
result = Analysis()
# Disabling protected-access to avoid making Analysis._cc_pyohrli public.
result._cc_analysis = (
self._cc_pyohrli.analyze( # pylint: disable=protected-access
np.asarray(signal).astype(np.float32).ravel().data,
)
)
return result

def analysis_distance(self, analysis_a: Analysis, analysis_b: Analysis) -> float:
"""Computes the distance between two psychoacoustic analyses.

Args:
analysis_a: An Analysis instance to compare.
analysis_b: Another Analysis instance to compare with.

Returns:
The Zimtohrli distance between the two analyses.
"""
return self._cc_pyohrli.analysis_distance(
# Disabling protected-access to avoid making Analysis._cc_pyohrli
# public.
analysis_a._cc_analysis, # pylint: disable=protected-access
analysis_b._cc_analysis, # pylint: disable=protected-access
)

def distance(self, signal_a: npt.ArrayLike, signal_b: npt.ArrayLike) -> float:
"""Computes the distance between two signals.

Expand Down
6 changes: 1 addition & 5 deletions cpp/zimt/pyohrli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,7 @@ def test_distance(self, a_hz: float, b_hz: float, distance: float):
sample_rate = 48000.0
metric = pyohrli.Pyohrli(sample_rate)
signal_a = np.sin(np.linspace(0.0, np.pi * 2 * a_hz, int(sample_rate)))
analysis_a = metric.analyze(signal_a)
signal_b = np.sin(np.linspace(0.0, np.pi * 2 * b_hz, int(sample_rate)))
analysis_b = metric.analyze(signal_b)
analysis_distance = metric.analysis_distance(analysis_a, analysis_b)
self.assertLess(abs(analysis_distance - distance), 1e-3)
distance = metric.distance(signal_a, signal_b)
self.assertLess(abs(distance - distance), 1e-3)

Expand All @@ -69,7 +65,7 @@ def test_nyquist_threshold(self):
signal = np.sin(np.linspace(0.0, np.pi * 2 * 440.0, int(sample_rate)))
# This would crash the program if pyohrli.cc didn't limit the upper
# threshold to half the sample rate.
metric.analyze(signal)
metric.distance(signal, signal)

@parameterize(
dict(zimtohrli_distance=0.0, mos=5.0),
Expand Down
49 changes: 49 additions & 0 deletions cpp/zimt/zimtohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,55 @@ Distance Zimtohrli::Distance(
}
}

hwy::AlignedNDArray<float, 2> Zimtohrli::StreamingSpectrogram(
hwy::Span<const float> signal) const {
hwy::AlignedNDArray<float, 2> chunk_energy_db_buffer(
{1, cam_filterbank->filter.Size()});
hwy::AlignedNDArray<float, 2> chunk_partial_energy_db_buffer(
{1, cam_filterbank->filter.Size()});
hwy::AlignedNDArray<float, 2> chunk_spectrogram_buffer(
{1, cam_filterbank->filter.Size()});
const size_t samples_per_chunk =
static_cast<size_t>(cam_filterbank->sample_rate / perceptual_sample_rate);
hwy::AlignedNDArray<float, 2> chunk_channels_buffer(
{samples_per_chunk, cam_filterbank->filter.Size()});
hwy::AlignedNDArray<float, 2> spectrogram(
{static_cast<size_t>(signal.size() / samples_per_chunk),
cam_filterbank->filter.Size()});
FilterbankState filter_state = cam_filterbank->filter.NewState();
for (size_t step = 0; (step + 1) * samples_per_chunk < signal.size();
++step) {
cam_filterbank->filter.Filter(
hwy::Span<const float>(signal.data() + step * samples_per_chunk,
samples_per_chunk),
filter_state, chunk_channels_buffer);
ComputeEnergy(chunk_channels_buffer, chunk_energy_db_buffer);
ToDb(chunk_energy_db_buffer, full_scale_sine_db, epsilon,
chunk_energy_db_buffer);
if (apply_masking) {
masking.CutFullyMasked(chunk_energy_db_buffer, cam_filterbank->cam_delta,
chunk_partial_energy_db_buffer);
} else {
hwy::CopyBytes(chunk_energy_db_buffer.data(),
chunk_partial_energy_db_buffer.data(),
chunk_energy_db_buffer.memory_size() * sizeof(float));
}
if (apply_loudness) {
loudness.PhonsFromSPL(chunk_partial_energy_db_buffer,
cam_filterbank->thresholds_hz,
chunk_spectrogram_buffer);
} else {
hwy::CopyBytes(
chunk_partial_energy_db_buffer.data(),
chunk_spectrogram_buffer.data(),
chunk_partial_energy_db_buffer.memory_size() * sizeof(float));
}
hwy::CopyBytes(chunk_spectrogram_buffer.data(), spectrogram[{step}].data(),
chunk_spectrogram_buffer.memory_size() * sizeof(float));
}
return spectrogram;
}

void Zimtohrli::Spectrogram(
hwy::Span<const float> signal, FilterbankState& state,
hwy::AlignedNDArray<float, 2>& channels,
Expand Down
16 changes: 15 additions & 1 deletion cpp/zimt/zimtohrli.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,27 @@ struct Zimtohrli {
hwy::AlignedNDArray<float, 2>& partial_energy_channels_db,
hwy::AlignedNDArray<float, 2>& spectrogram) const;

// Spectrogram without chunk processing.
// Spectrogram defaulting to a fresh filter bank state.
void Spectrogram(hwy::Span<const float> signal,
hwy::AlignedNDArray<float, 2>& channels,
hwy::AlignedNDArray<float, 2>& energy_channels_db,
hwy::AlignedNDArray<float, 2>& partial_energy_channels_db,
hwy::AlignedNDArray<float, 2>& spectrogram) const;

// Memory thrifty spectrogram calculation that computes one chunk (samples per
// perceptual sample rate period) at a time.
//
// Since it only computes one chunk at a time it only needs a chunk of
// channels, energy_channels_db, and partial_energy_channels_db that it
// allocates and reuses.
//
// signal is a span of audio samples between -1 and 1.
//
// Returns a (num_downscaled_samples, num_channels)-shaped array of Phons
// values reprecenting the perceptual intensity of each channel.
hwy::AlignedNDArray<float, 2> StreamingSpectrogram(
hwy::Span<const float> signal) const;

// Returns the perceptual distance between the two spectrograms.
//
// spectrogram_a and spectrogram_b are (num_samples, num_channels)-shaped
Expand Down
Loading
Loading