Skip to content

Commit

Permalink
Merge pull request #122 from jyrkialakuijala/tabuli
Browse files Browse the repository at this point in the history
~5% worse, but faster and simpler
  • Loading branch information
jyrkialakuijala authored Jul 11, 2024
2 parents 59cc535 + 6000340 commit b112037
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 97 deletions.
191 changes: 97 additions & 94 deletions cpp/zimt/fourier_bank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,121 +71,124 @@ float SimpleDb(float energy) {
return kMul * log(energy + epsilon);
}

void FinalizeDb(hwy::AlignedNDArray<float, 2>& channels, float mul,
size_t out_ix) {
float masker_down[kNumRotators];
for (int k = 0; k < kNumRotators; ++k) {
float v = SimpleDb(mul * channels[{out_ix}][k]);
channels[{out_ix}][k] = Loudness(k, v);
void PrepareMasker(hwy::AlignedNDArray<float, 2>& channels,
float *masker,
size_t out_ix) {
if (out_ix < 3) {
for (int k = 0; k < kNumRotators; ++k) {
masker[k] = channels[{out_ix}][k];
}
} else {
// convolve in time and freq, 5 freq bins, 3 time bins
static const double c[12] = {
0.011551012731481482,
0.02009898726851852,
0.27419898726851855,

-0.04009898726851849,
0.3270268229166667,
0.6400989872685185,

0.36397005208333333,
0.6505010127314814,
0.8000989872685186,

-0.15930101273148148,
1.5483130497685185,
8.31009898726852,
};
static const float div = 1.0 / (2*(c[0]+c[1]+c[2]+c[3]+c[4]+c[5]+c[6]+c[7]+c[8])+c[9]+c[10]+c[11]);
for (int k = 0; k < kNumRotators; ++k) {
int prev3 = std::max(0, k - 3);
int prev2 = std::max(0, k - 2);
int prev1 = std::max(0, k - 1);
int currk = k;
int next1 = std::min<int>(kNumRotators - 1, k + 1);
int next2 = std::min<int>(kNumRotators - 1, k + 2);
int next3 = std::min<int>(kNumRotators - 1, k + 3);
size_t oi2 = out_ix - 2;
size_t oi1 = out_ix - 1;
size_t oi0 = out_ix - 0;

float v =
channels[{oi2}][prev3] * c[0] + channels[{oi1}][prev3] * c[1] + channels[{oi0}][prev3] * c[2] +
channels[{oi2}][prev2] * c[3] + channels[{oi1}][prev2] * c[4] + channels[{oi0}][prev2] * c[5] +
channels[{oi2}][prev1] * c[6] + channels[{oi1}][prev1] * c[7] + channels[{oi0}][prev1] * c[8] +
channels[{oi2}][currk] * c[9] + channels[{oi1}][currk] * c[10] + channels[{oi0}][currk] * c[11] +
channels[{oi2}][next1] * c[6] + channels[{oi1}][next1] * c[7] + channels[{oi0}][next1] * c[8] +
channels[{oi2}][next2] * c[3] + channels[{oi1}][next2] * c[4] + channels[{oi0}][next2] * c[5] +
channels[{oi2}][next3] * c[0] + channels[{oi1}][next3] * c[1] + channels[{oi0}][next3] * c[2];

masker[k] = v * div;
}
}
double masker = 0.0;
static const double octaves_in_20_to_20000 = log(20000/20.)/log(2);
static const double octaves_per_rot =
octaves_in_20_to_20000 / float(kNumRotators - 1);
static const double masker_step_per_octave_up_0 = 19.53945781131615;
static const double masker_step_per_octave_up_1 = 24.714118008386887;
static const double masker_step_per_octave_up_2 = 6.449301354309956;
static const double masker_step_per_octave_up_0 = 20.54547806594578;
static const double masker_step_per_octave_up_1 = 24.608097753757256;
static const double masker_step_per_octave_up_2 = 6.0;
static const double masker_step_per_rot_up_0 = octaves_per_rot * masker_step_per_octave_up_0;
static const double masker_step_per_rot_up_1 = octaves_per_rot * masker_step_per_octave_up_1;
static const double masker_step_per_rot_up_2 = octaves_per_rot * masker_step_per_octave_up_2;
static const double masker_gap_up = 21.309406898722074;
static const float maskingStrengthUp = 0.2056434702527141;
static const float up_blur = 0.9442717063037425;
static const float fraction_up = 1.1657467617827404;

static const double masker_step_per_octave_down = 53.40273959309446;
static const double masker_step_per_octave_down = 53.40075984772409;
static const double masker_step_per_rot_down = octaves_per_rot * masker_step_per_octave_down;
static const double masker_gap_down = 19.08401096304284;
static const float maskingStrengthDown = 0.18030917038808858;
static const float down_blur = 0.7148792180987857;
// propagate masker up
float mask = 0;
for (int k = 0; k < kNumRotators; ++k) {
float v = masker[k];
if (mask < v) {
mask = v;
}
masker[k] = std::max<float>(masker[k], mask);
if (3 * k < kNumRotators) {
mask -= masker_step_per_rot_up_0;
} else if (3 * k < 2 * kNumRotators) {
mask -= masker_step_per_rot_up_1;
} else {
mask -= masker_step_per_rot_up_2;
}
}
// propagate masker down
mask = 0;
for (int k = kNumRotators - 1; k >= 0; --k) {
float v = masker[k];
if (mask < v) {
mask = v;
}
masker[k] = std::max<float>(masker[k], mask);
mask -= masker_step_per_rot_down;
}
}

static const float min_limit = -11.3968870989223;
static const float fraction_down = 1.0197608300379997;
void FinalizeDb(hwy::AlignedNDArray<float, 2>& channels, float mul,
size_t out_ix) {
float masker[kNumRotators];
for (int k = 0; k < kNumRotators; ++k) {
float v = SimpleDb(mul * channels[{out_ix}][k]);
channels[{out_ix}][k] = Loudness(k, v);
}
PrepareMasker(channels, &masker[0], out_ix);

static const float temporal0 = 0.09979167061501665;
static const float temporal1 = 0.14429505133534495;
static const float temporal2 = 0.009228598592129168;
static const float weightp = 0.1792443302507868;
static const float weightm = 0.7954490998745948;

static const float mask_k = 0.08709005149742773;
static const double masker_gap = 20.716199363425925;
static const float maskingStrength = 0.22591336897956596;

static const float min_limit = -11.3968870989223;

// Scan frequencies from bottom to top, let lower frequencies to mask higher frequencies.
// 'masker' maintains the masking envelope from one bin to next.
for (int k = 0; k < kNumRotators; ++k) {
float v = channels[{out_ix}][k];
if (out_ix != 0) {
v = (1.0 - mask_k) * v + mask_k * channels[{out_ix - 1}][k];
}
double mask = masker[k] - masker_gap;
if (v < min_limit) {
v = min_limit;
}
float v2 = (1 - up_blur) * v2 + up_blur * v;
if (k == 0) {
v2 = v;
}
if (masker < v2) {
masker = v2;
}
float mask = fraction_up * masker - masker_gap_up;
if (v < mask) {
v = maskingStrengthUp * mask + (1.0 - maskingStrengthUp) * v;
}
channels[{out_ix}][k] = v;
if (3 * k < kNumRotators) {
masker -= masker_step_per_rot_up_0;
} else if (3 * k < 2 * kNumRotators) {
masker -= masker_step_per_rot_up_1;
} else {
masker -= masker_step_per_rot_up_2;
}
}
// Scan frequencies from top to bottom, let higher frequencies to mask lower frequencies.
// 'masker' maintains the masking envelope from one bin to next.
masker = 0.0;
for (int k = kNumRotators - 1; k >= 0; --k) {
float v = channels[{out_ix}][k];
if (out_ix != 0) {
v = (1.0 - mask_k) * v + mask_k * channels[{out_ix - 1}][k];
}
float v2 = (1 - down_blur) * v2 + down_blur * v;
if (k == kNumRotators - 1) {
v2 = v;
}
if (masker < v) {
masker = v;
}
float mask = fraction_down * masker - masker_gap_down;
if (v < mask) {
v = maskingStrengthDown * mask + (1.0 - maskingStrengthDown) * v;
v = maskingStrength * mask + (1.0 - maskingStrength) * v;
}
channels[{out_ix}][k] = v;
masker -= masker_step_per_rot_down;
}
// temporal masker
if (out_ix >= 3) {
for (int k = 0; k < kNumRotators; ++k) {
float m = (temporal0 * channels[{out_ix - 1}][k] +
temporal1 * channels[{out_ix - 2}][k] +
temporal2 * channels[{out_ix - 3}][k]) / (temporal0 + temporal1 + temporal2);
if (m > channels[{out_ix}][k]) {
channels[{out_ix}][k] -= weightp * (m - channels[{out_ix}][k]);
} else {
channels[{out_ix}][k] -= weightm * (m - channels[{out_ix}][k]);
}
/*
// todo(jyrki): explore with this
static const float temporal_masker0 = 0.1387454636244773;
channels[{out_ix}][k] -=
temporal_masker0 * (channels[{out_ix - 1}][k] - channels[{out_ix}][k]);
static const float temporal_masker1 = 0.08715440670406614;
channels[{out_ix}][k] -=
temporal_masker1 * (channels[{out_ix - 2}][k] - channels[{out_ix}][k]);
static const float temporal_masker2 = -0.03785233735225447;
channels[{out_ix}][k] -=
temporal_masker2 * (channels[{out_ix - 3}][k] - channels[{out_ix}][k]);
*/
}
}
}

Expand Down Expand Up @@ -254,8 +257,8 @@ Rotators::Rotators(int num_channels, std::vector<float> frequency,
window[i] = std::pow(kWindow, bw * kBandwidthMagic);
float windowM1 = 1.0f - window[i];
float f = frequency[i] * 2.0f * M_PI / sample_rate;
static const float full_scale_sine_db = exp(75.27858635739499);
const float gainer = 2.0f * sqrt(full_scale_sine_db);
static const float full_scale_sine_db = exp(76.66488071851488);
const float gainer = sqrt(full_scale_sine_db);
gain[i] = gainer * filter_gains[i] * pow(windowM1, 3.0);
rot[0][i] = float(std::cos(f));
rot[1][i] = float(-std::sin(f));
Expand Down
2 changes: 1 addition & 1 deletion cpp/zimt/fourier_bank.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

namespace tabuli {

constexpr int64_t kNumRotators = 150;
constexpr int64_t kNumRotators = 128;

struct PerChannel {
// [0..1] is for real and imag of 1st leaking accumulation
Expand Down
4 changes: 2 additions & 2 deletions cpp/zimt/zimtohrli.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ struct Zimtohrli {
std::optional<CamFilterbank> cam_filterbank;

// The window in perceptual_sample_rate time steps when compting the NSIM.
size_t nsim_step_window = 16;
size_t nsim_step_window = 8;

// The window in channels when computing the NSIM.
size_t nsim_channel_window = 30;
size_t nsim_channel_window = 16;

// The window of the dynamic time warp that matches audio signals.
//
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.

0 comments on commit b112037

Please sign in to comment.