Skip to content

Commit

Permalink
Made the MOS mapping parameters mutable.
Browse files Browse the repository at this point in the history
- Made the Zimtohrli instance hold a MOSMapper containing the parameters.
- Made all MOS mappings use the current Zimtohrli instance to map to MOS.
- Made the mutable MOS parameters modifiable from Go.
  • Loading branch information
Martin Bruse committed Jun 18, 2024
1 parent 80dc2be commit fc232d9
Show file tree
Hide file tree
Showing 16 changed files with 124 additions and 69 deletions.
14 changes: 7 additions & 7 deletions cpp/zimt/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ std::ostream& operator<<(std::ostream& outs, const DistanceData& data) {
return outs;
}

float GetMetric(float zimtohrli_score) {
float GetMetric(const zimtohrli::Zimtohrli& z, float zimtohrli_score) {
if (absl::GetFlag(FLAGS_output_zimtohrli_distance)) {
return zimtohrli_score;
}
return MOSFromZimtohrli(zimtohrli_score);
return z.mos_mapper.Map(zimtohrli_score);
}

int Main(int argc, char* argv[]) {
Expand Down Expand Up @@ -340,16 +340,16 @@ int Main(int argc, char* argv[]) {
z.Distance(false, file_a_spectrograms[channel_index], spectrogram_b)
.value;
if (per_channel) {
std::cout << GetMetric(distance) << std::endl;
std::cout << GetMetric(z, distance) << std::endl;
} else {
sum_of_squares += distance * distance;
}
}
if (!per_channel) {
for (int file_b_index = 0; file_b_index < file_b_vector.size();
++file_b_index) {
std::cout << GetMetric(std::sqrt(sum_of_squares /
float(file_a->Info().channels)))
std::cout << GetMetric(z, std::sqrt(sum_of_squares /
float(file_a->Info().channels)))
<< std::endl;
}
}
Expand Down Expand Up @@ -413,13 +413,13 @@ int Main(int argc, char* argv[]) {
const float distance = phons_channel_distance.distance.value;
sum_of_squares += distance * distance;

std::cout << " Channel MOS: " << MOSFromZimtohrli(distance)
std::cout << " Channel MOS: " << z.mos_mapper.Map(distance)
<< std::endl;
}
const float zimtohrli_file_distance =
std::sqrt(sum_of_squares / float(comparison.analysis_a.size()));
std::cout << " File distance: " << zimtohrli_file_distance << std::endl;
std::cout << " File MOS: " << MOSFromZimtohrli(zimtohrli_file_distance)
std::cout << " File MOS: " << z.mos_mapper.Map(zimtohrli_file_distance)
<< std::endl;
}
return 0;
Expand Down
24 changes: 17 additions & 7 deletions cpp/zimt/goohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ int NumLoudnessTFParams() {
return NUM_LOUDNESS_T_F_PARAMS;
}

int NumMOSMapperParams() {
CHECK_EQ(NUM_MOS_MAPPER_PARAMS, zimtohrli::MOSMapper{}.params.size());
return NUM_MOS_MAPPER_PARAMS;
}

EnergyAndMaxAbsAmplitude Measure(const float* signal, int size) {
hwy::AlignedNDArray<float, 1> signal_array({static_cast<size_t>(size)});
hwy::CopyBytes(signal, signal_array.data(), size * sizeof(float));
Expand All @@ -67,11 +72,12 @@ EnergyAndMaxAbsAmplitude NormalizeAmplitude(float max_abs_amplitude,
.MaxAbsAmplitude = measurements.max_abs_amplitude};
}

float MOSFromZimtohrli(float zimtohrli_distance) {
return zimtohrli::MOSFromZimtohrli(zimtohrli_distance);
float MOSFromZimtohrli(const Zimtohrli zimtohrli, float zimtohrli_distance) {
const zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
return z->mos_mapper.Map(zimtohrli_distance);
}

Zimtohrli CreateZimtohrli(ZimtohrliParameters params) {
Zimtohrli CreateZimtohrli(const ZimtohrliParameters params) {
zimtohrli::Cam cam{.minimum_bandwidth_hz = params.FrequencyResolution,
.filter_order = params.FilterOrder,
.filter_pass_band_ripple = params.FilterPassBandRipple,
Expand All @@ -88,9 +94,9 @@ void FreeZimtohrli(Zimtohrli zimtohrli) {
delete static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
}

float Distance(Zimtohrli zimtohrli, float* data_a, int size_a, float* data_b,
int size_b) {
zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
float Distance(const Zimtohrli zimtohrli, float* data_a, int size_a,
float* data_b, int size_b) {
const zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
hwy::AlignedNDArray<float, 1> signal_a({static_cast<size_t>(size_a)});
hwy::CopyBytes(data_a, signal_a.data(), size_a * sizeof(float));
hwy::AlignedNDArray<float, 1> signal_b({static_cast<size_t>(size_b)});
Expand All @@ -105,7 +111,7 @@ float Distance(Zimtohrli zimtohrli, float* data_a, int size_a, float* data_b,
}

ZimtohrliParameters GetZimtohrliParameters(const Zimtohrli zimtohrli) {
zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
const zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
ZimtohrliParameters result;
result.SampleRate = z->cam_filterbank->sample_rate;
const hwy::AlignedNDArray<float, 2>& thresholds =
Expand Down Expand Up @@ -133,6 +139,8 @@ ZimtohrliParameters GetZimtohrliParameters(const Zimtohrli zimtohrli) {
sizeof(result.LoudnessLUParams));
std::memcpy(result.LoudnessTFParams, z->loudness.t_f_params.data(),
sizeof(result.LoudnessTFParams));
std::memcpy(result.MOSMapperParams, z->mos_mapper.params.data(),
sizeof(result.MOSMapperParams));
return result;
}

Expand All @@ -157,6 +165,8 @@ void SetZimtohrliParameters(Zimtohrli zimtohrli,
sizeof(parameters.LoudnessLUParams));
std::memcpy(z->loudness.t_f_params.data(), parameters.LoudnessTFParams,
sizeof(parameters.LoudnessTFParams));
std::memcpy(z->mos_mapper.params.data(), parameters.MOSMapperParams,
sizeof(parameters.MOSMapperParams));
}

ZimtohrliParameters DefaultZimtohrliParameters(float sample_rate) {
Expand Down
11 changes: 3 additions & 8 deletions cpp/zimt/mos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,14 @@ namespace zimtohrli {

namespace {

const std::array<float, 3> params = {1.000e+00, -7.449e-09, 3.344e+00};

float sigmoid(float x) {
float sigmoid(const std::array<float, 3>& params, float x) {
return params[0] / (params[1] + std::exp(params[2] * x));
}

const float zero_crossing_reciprocal = 1.0 / sigmoid(0);

} // namespace

// Optimized using `mos_mapping.ipynb`.
float MOSFromZimtohrli(float zimtohrli_distance) {
return 1.0 + 4.0 * sigmoid(zimtohrli_distance) * zero_crossing_reciprocal;
float MOSMapper::Map(float zimtohrli_distance) const {
return 1.0 + 4.0 * sigmoid(params, zimtohrli_distance) / sigmoid(params, 0);
}

} // namespace zimtohrli
27 changes: 20 additions & 7 deletions cpp/zimt/mos.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,28 @@
#ifndef CPP_ZIMT_MOS_H_
#define CPP_ZIMT_MOS_H_

#include <array>

namespace zimtohrli {

// Returns a _very_approximate_ mean opinion score based on the
// provided Zimtohrli distance.
// This is calibrated using default settings of v0.1.5, with a
// minimum channel bandwidth (zimtohrli::Cam.minimum_bandwidth_hz)
// of 5Hz and perceptual sample rate
// (zimtohrli::Distance(..., perceptual_sample_rate, ...) of 100Hz.
float MOSFromZimtohrli(float zimtohrli_distance);
// Maps from Zimtohrli distance to MOS.
struct MOSMapper {
// Returns a _very_approximate_ mean opinion score based on the
// provided Zimtohrli distance.
//
// Computed by:
// s(x) = params[0] / (params[1] + e^(params[2] * x))
// MOS = 1 + 4 * s(distance)) / s(0)
//
// This is calibrated using default settings of v0.1.5, with a
// minimum channel bandwidth (zimtohrli::Cam.minimum_bandwidth_hz)
// of 5Hz and perceptual sample rate
// (zimtohrli::Distance(..., perceptual_sample_rate, ...) of 100Hz.
float Map(float zimtohrli_distance) const;

// Params used when mapping Zimtohrli distance to MOS.
std::array<float, 3> params = {1.000e+00, -7.449e-09, 3.344e+00};
};

} // namespace zimtohrli

Expand Down
3 changes: 2 additions & 1 deletion cpp/zimt/mos_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ TEST(MOS, MOSFromZimtohrli) {
const std::vector<float> zimt_scores = {0, 0.1, 0.5, 0.7, 1.0};
const std::vector<float> mos = {5.0, 3.8630697727203369, 1.751483678817749,
1.3850023746490479, 1.1411819458007812};
const MOSMapper m;
for (size_t index = 0; index < zimt_scores.size(); ++index) {
ASSERT_NEAR(MOSFromZimtohrli(zimt_scores[index]), mos[index], 1e-2);
ASSERT_NEAR(m.Map(zimt_scores[index]), mos[index], 1e-2);
}
}

Expand Down
30 changes: 13 additions & 17 deletions cpp/zimt/pyohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,22 @@ PyObject* Pyohrli_distance(PyohrliObject* self, PyObject* const* args,
return PyFloat_FromDouble(distance.value);
}

PyObject* Pyohrli_mos_from_zimtohrli(PyohrliObject* self, PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 1) {
return BadArgument("not exactly 1 argument provided");
}
return PyFloat_FromDouble(
self->zimtohrli->mos_mapper.Map(PyFloat_AsDouble(args[0])));
}

PyMethodDef Pyohrli_methods[] = {
{"distance", (PyCFunction)Pyohrli_distance, METH_FASTCALL,
"Returns the distance between the two provided signals."},
{"mos_from_zimtohrli", (PyCFunction)Pyohrli_mos_from_zimtohrli,
METH_FASTCALL,
"Returns an approximate mean opinion score based on the provided "
"Zimtohrli distance."},
{nullptr} /* Sentinel */
};

Expand All @@ -150,28 +163,11 @@ PyTypeObject PyohrliType = {
.tp_new = PyType_GenericNew,
};

PyObject* MOSFromZimtohrli(PyohrliObject* self, PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 1) {
return BadArgument("not exactly 1 argument provided");
}
return PyFloat_FromDouble(
zimtohrli::MOSFromZimtohrli(PyFloat_AsDouble(args[0])));
}

static PyMethodDef PyohrliModuleMethods[] = {
{"MOSFromZimtohrli", (PyCFunction)MOSFromZimtohrli, METH_FASTCALL,
"Returns an approximate mean opinion score based on the provided "
"Zimtohrli distance."},
{NULL, NULL, 0, NULL},
};

PyModuleDef PyohrliModule = {
.m_base = PyModuleDef_HEAD_INIT,
.m_name = "pyohrli",
.m_doc = "Python wrapper around the C++ zimtohrli library.",
.m_size = -1,
.m_methods = PyohrliModuleMethods,
};

PyMODINIT_FUNC PyInit__pyohrli(void) {
Expand Down
9 changes: 4 additions & 5 deletions cpp/zimt/pyohrli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
import _pyohrli


def mos_from_zimtohrli(zimtohrli_distance: float) -> float:
"""Returns an approximate mean opinion score based on the provided Zimtohrli distance."""
return _pyohrli.MOSFromZimtohrli(zimtohrli_distance)


class Pyohrli:
"""Wrapper around C++ zimtohrli::Zimtohrli."""

Expand Down Expand Up @@ -56,6 +51,10 @@ def distance(self, signal_a: npt.ArrayLike, signal_b: npt.ArrayLike) -> float:
np.asarray(signal_b).astype(np.float32).ravel().data,
)

def mos_from_zimtohrli(self, zimtohrli_distance: float) -> float:
"""Returns an approximate mean opinion score based on the provided Zimtohrli distance."""
return self._cc_pyohrli.mos_from_zimtohrli(zimtohrli_distance)

@property
def full_scale_sine_db(self) -> float:
"""Reference intensity for an amplitude 1.0 sine wave at 1kHz.
Expand Down
3 changes: 2 additions & 1 deletion cpp/zimt/pyohrli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def test_nyquist_threshold(self):
dict(zimtohrli_distance=1.0, mos=1.1411819458007812),
)
def test_mos_from_zimtohrli(self, zimtohrli_distance: float, mos: float):
metric = pyohrli.Pyohrli(48000.0)
self.assertAlmostEqual(
mos, pyohrli.mos_from_zimtohrli(zimtohrli_distance), places=3
mos, metric.mos_from_zimtohrli(zimtohrli_distance), places=3
)


Expand Down
4 changes: 4 additions & 0 deletions cpp/zimt/zimtohrli.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "zimt/cam.h"
#include "zimt/loudness.h"
#include "zimt/masking.h"
#include "zimt/mos.h"

namespace zimtohrli {

Expand Down Expand Up @@ -326,6 +327,9 @@ struct Zimtohrli {
// Perceptual intensity model.
Loudness loudness;

// MOS mapping model.
MOSMapper mos_mapper;

// Whether the masking model is applied when creating spectrograms.
bool apply_masking = true;

Expand Down
4 changes: 2 additions & 2 deletions go/bin/compare/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ func main() {
}

if *zimtohrli {
g := goohrli.New(zimtohrliParameters)
getMetric := func(f float64) float64 {
if *outputZimtohrliDistance {
return f
}
return goohrli.MOSFromZimtohrli(f)
return g.MOSFromZimtohrli(f)
}

if err := zimtohrliParameters.Update([]byte(*zimtohrliParametersJSON)); err != nil {
Expand All @@ -117,7 +118,6 @@ func main() {
log.Printf("Using %+v", zimtohrliParameters)
}
zimtohrliParameters.SampleRate = signalA.Rate
g := goohrli.New(zimtohrliParameters)
if *perChannel {
for channelIndex := range signalA.Samples {
measurement := goohrli.Measure(signalA.Samples[channelIndex])
Expand Down
16 changes: 14 additions & 2 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ func main() {
optimizeNumSteps := flag.Float64("optimize_num_steps", 1000, "Number of steps for the simulated annealing.")
workers := flag.Int("workers", runtime.NumCPU(), "Number of concurrent workers for tasks.")
failFast := flag.Bool("fail_fast", false, "Whether to panic immediately on any error.")
optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.")
flag.Parse()

if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" {
if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" {
flag.Usage()
os.Exit(1)
}
Expand All @@ -88,10 +89,21 @@ func main() {
f.Sync()
}
}
err = bundles.Optimize(*optimizeStartStep, *optimizeNumSteps, optimizeLog)
if err = bundles.Optimize(*optimizeStartStep, *optimizeNumSteps, optimizeLog); err != nil {
log.Fatal(err)
}
}

if *optimizeMapping != "" {
bundles, err := data.OpenBundles(*optimizeMapping)
if err != nil {
log.Fatal(err)
}
params, err := bundles.OptimizeMapping()
if err != nil {
log.Fatal(err)
}
fmt.Println(params)
}

if *calculate != "" {
Expand Down
4 changes: 4 additions & 0 deletions go/data/study.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,10 @@ func (r ReferenceBundles) Split(rng *rand.Rand, split float64) (ReferenceBundles
return left, right
}

func (r ReferenceBundles) OptimizeMapping() ([]float32, error) {
return nil, nil
}

// OptimizationEvent is a step in the optimization process.
type OptimizationEvent struct {
Parameters goohrli.Parameters
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.
Loading

0 comments on commit fc232d9

Please sign in to comment.