Skip to content

Commit

Permalink
~0.5 % improvement
Browse files Browse the repository at this point in the history
|Score type |MSE               |Min score         |Max score         |Mean score        |
|-----------|------------------|------------------|------------------|------------------|
|Zimtohrli  |0.077647743787295 |0.608127552926780 |0.829014091535645 |0.728679787505096 |
|ViSQOL     |0.115330916105424 |0.520833375452983 |0.801480831107469 |0.675101633981268 |
|2f         |0.129541391104905 |0.484687555319526 |0.797475783883375 |0.661870345773127 |
|PESQ       |0.147425552045669 |0.342342966279351 |0.841271127756762 |0.647128996775172 |
|CDPAM      |0.153471222942756 |0.441558428344727 |0.728779141125759 |0.620699318941738 |
|PARLAQ     |0.185057687192323 |0.445261140223642 |0.784370761057963 |0.587162756572532 |
|AQUA       |0.223207996944378 |0.331645933512413 |0.739286336419790 |0.547804951221731 |
|PEAQB      |0.225217321572038 |0.278744167467764 |0.851011116004117 |0.553935720513487 |
|DPAM       |0.315810440183130 |0.186717781679534 |0.690564701717118 |0.460415212267967 |
|WARP-Q     |0.339686211572685 |0.067600137543649 |0.777119464646524 |0.475793617709890 |
|GVPMOS     |0.412937133868407 |0.006851162794410 |0.783946603687895 |0.412912222208318 |
  • Loading branch information
jyrkialakuijala committed Jul 3, 2024
1 parent f471890 commit 780adbe
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 55 deletions.
85 changes: 30 additions & 55 deletions cpp/zimt/fourier_bank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ float Loudness(int k, float val) {
{ 0.354, -3.5, 12.3, }, // 12500
};
const float *vals = &pars[k / 5][0];
static float constant1 = 32.617115877848235;
static float constant2 = 27.588336629470223;
static float constant1 = 33.10987678814273;
static float constant2 = 27.720606627666342;
val *= (constant1 + vals[1]) * (1.0 / constant2);
return val;
}

float SimpleDb(float energy) {
// ideally 78.3 db
static const float full_scale_sine_db = 77.4742570462833;
static const float full_scale_sine_db = 77.47500423191678;
static const float exp_full_scale_sine_db = exp(full_scale_sine_db);
// epsilon, but the biggest one you saw (~4.95e23)
static const float epsilon = 1.0033294789821357e-09 * exp_full_scale_sine_db;
Expand All @@ -71,49 +71,35 @@ float SimpleDb(float energy) {

void FinalizeDb(hwy::AlignedNDArray<float, 2>& channels, float mul,
size_t out_ix) {
for (int k = 0; k < kNumRotators; ++k) {
float v = SimpleDb(mul * channels[{out_ix}][k]);
channels[{out_ix}][k] = Loudness(k, v);
}
double masker = 0.0;
static const double octaves_in_20_to_20000 = log(20000/20.)/log(2);
static const double octaves_per_rot =
octaves_in_20_to_20000 / float(kNumRotators - 1);
static const double masker_step_per_octave_up_0 = 14.713710868625487;
static const double masker_step_per_octave_up_1 = 21.82837106569622;
static const double masker_step_per_octave_up_2 = 6.398379411619291;
static const double masker_step_per_octave_up_0 = 16.495427054351488;
static const double masker_step_per_octave_up_1 = 22.552019717473833;
static const double masker_step_per_octave_up_2 = 20.082159717473832;
static const double masker_step_per_rot_up_0 = octaves_per_rot * masker_step_per_octave_up_0;
static const double masker_step_per_rot_up_1 = octaves_per_rot * masker_step_per_octave_up_1;
static const double masker_step_per_rot_up_2 = octaves_per_rot * masker_step_per_octave_up_2;
static const double masker_gap_up = 20.73060456058724;
static const float maskingStrengthUp = 0.13574932102981796;
static const double masker_gap_up = 19.140338374861235;
static const float maskingStrengthUp = 0.1252262923615547;
static const float up_blur = 0.8738593591692092;
static const float fraction_up = 1.0189925926484509;

static const double masker_step_per_octave_down = 53.404905982421795;
static const double masker_step_per_octave_down = 42.41172783112732;
static const double masker_step_per_rot_down = octaves_per_rot * masker_step_per_octave_down;
static const double masker_gap_down = 18.944464865908305;
static const float maskingStrengthDown = 0.18023700284914426;
static const double masker_gap_down = 19.59079250393617;
static const float maskingStrengthDown = 0.19329999999999992;
static const float down_blur = 0.714425315233319;

static const float min_limit = -11.397341001787765;
static const float fraction_down = 1.0198983999315951;

static const float temporal0 = 0.0999047126828788;
static const float temporal1 = 0.08127754119644388;
static const float temporal2 = 0.009595578765516729;
static float weightp = 0.17921002086852425;
static float weightm = 0.04999327787335808;

static float mask_k = 0.017060747885286255;

// Scan frequencies from bottom to top, let lower frequencies to mask higher frequencies.
// 'masker' maintains the masking envelope from one bin to next.
static const float temporal_masker0 = 0.13104546362447728;
static const float temporal_masker1 = 0.09719740670406614;
static const float temporal_masker2 = -0.03085233735225447;

for (int k = 0; k < kNumRotators; ++k) {
float v = channels[{out_ix}][k];
if (out_ix != 0) {
v = (1.0 - mask_k) * v + mask_k * channels[{out_ix - 1}][k];
}
float v = SimpleDb(mul * channels[{out_ix}][k]);
if (v < min_limit) {
v = min_limit;
}
Expand All @@ -124,10 +110,12 @@ void FinalizeDb(hwy::AlignedNDArray<float, 2>& channels, float mul,
if (masker < v2) {
masker = v2;
}
float mask = fraction_up * masker - masker_gap_up;
float mask = masker - masker_gap_up;

if (v < mask) {
v = maskingStrengthUp * mask + (1.0 - maskingStrengthUp) * v;
}

channels[{out_ix}][k] = v;
if (3 * k < kNumRotators) {
masker -= masker_step_per_rot_up_0;
Expand All @@ -142,46 +130,33 @@ void FinalizeDb(hwy::AlignedNDArray<float, 2>& channels, float mul,
masker = 0.0;
for (int k = kNumRotators - 1; k >= 0; --k) {
float v = channels[{out_ix}][k];
if (out_ix != 0) {
v = (1.0 - mask_k) * v + mask_k * channels[{out_ix - 1}][k];
}
float v2 = (1 - down_blur) * v2 + down_blur * v;
if (k == kNumRotators - 1) {
v2 = v;
}
if (masker < v) {
masker = v;
}
float mask = fraction_down * masker - masker_gap_down;
float mask = masker - masker_gap_down;
if (v < mask) {
v = maskingStrengthDown * mask + (1.0 - maskingStrengthDown) * v;
}
channels[{out_ix}][k] = v;
masker -= masker_step_per_rot_down;
}
for (int k = 0; k < kNumRotators; ++k) {
channels[{out_ix}][k] = Loudness(k, channels[{out_ix}][k]);
}
// temporal masker
if (out_ix >= 3) {
for (int k = 0; k < kNumRotators; ++k) {
float m = (temporal0 * channels[{out_ix - 1}][k] +
temporal1 * channels[{out_ix - 2}][k] +
temporal2 * channels[{out_ix - 3}][k]) / (temporal0 + temporal1 + temporal2);
if (m > channels[{out_ix}][k]) {
channels[{out_ix}][k] -= weightp * (m - channels[{out_ix}][k]);
} else {
channels[{out_ix}][k] -= weightm * (m - channels[{out_ix}][k]);
}
/*
// todo(jyrki): explore with this
static const float temporal_masker0 = 0.1387454636244773;
channels[{out_ix}][k] -=
temporal_masker0 * (channels[{out_ix - 1}][k] - channels[{out_ix}][k]);
static const float temporal_masker1 = 0.08715440670406614;
channels[{out_ix}][k] -=
temporal_masker1 * (channels[{out_ix - 2}][k] - channels[{out_ix}][k]);
static const float temporal_masker2 = -0.03785233735225447;
channels[{out_ix}][k] -=
temporal_masker2 * (channels[{out_ix - 3}][k] - channels[{out_ix}][k]);
*/
float v0 = (channels[{out_ix - 1}][k] - channels[{out_ix}][k]);
float v1 = (channels[{out_ix - 2}][k] - channels[{out_ix}][k]);
float v2 = (channels[{out_ix - 3}][k] - channels[{out_ix}][k]);

channels[{out_ix}][k] -= temporal_masker0 * v0 +
temporal_masker1 * v1 +
temporal_masker2 * v2;
}
}
}
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.

0 comments on commit 780adbe

Please sign in to comment.