Skip to content

Commit

Permalink
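Fixed some performance issues in DTW: HwyDeltaNorm now computes a plain Euclidean norm (square root of the sum of squared deltas) instead of a generic order-p norm built from Exp/Log, and DTWSlice no longer precomputes per-row min/max values.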
Browse files (browse the repository at this point in the history)
Co-authored-by: Jyrki Alakuijala <jyrki.alakuijala@gmail.com>
  • Loading branch information
2 people authored and zond committed Jun 13, 2024
1 parent 115c4ac commit da39b39
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 25 deletions.
31 changes: 6 additions & 25 deletions cpp/zimt/dtw.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -66,26 +66,18 @@ float HwyMin(hwy::Span<const float> span) {
return min;
}

float HwyDeltaNorm(hwy::Span<const float> span_a, hwy::Span<const float> span_b,
float order, float max) {
if (max == 0) {
return 0;
}

float HwyDeltaNorm(hwy::Span<const float> span_a,
hwy::Span<const float> span_b) {
CHECK_EQ(span_a.size(), span_b.size());

const Vec order_vec = Set(d, order);
const Vec max_reciprocal = Div(Set(d, 1), Set(d, max));
double sum = 0;
for (size_t index = 0; index < span_a.size(); index += Lanes(d)) {
const Vec delta =
Sub(Load(d, span_a.data() + index), Load(d, span_b.data() + index));
const Vec downscaled_values = Mul(delta, max_reciprocal);
const Vec pows = Exp(d, Mul(order_vec, Log(d, downscaled_values)));
sum += static_cast<double>(ReduceSum(d, pows));
const Vec square = Mul(delta, delta);
sum += static_cast<double>(ReduceSum(d, square));
}
return static_cast<float>(max *
std::pow(sum, 1 / static_cast<double>(order)));
return static_cast<float>(std::sqrt(sum));
}

} // namespace HWY_NAMESPACE
Expand Down Expand Up @@ -120,20 +112,12 @@ std::vector<std::pair<size_t, size_t>> DTWSlice(
hwy::AlignedNDArray<float, 2>& cost_matrix) {
CHECK_EQ(cost_matrix.shape()[0], spec_a.shape()[0]);
CHECK_EQ(cost_matrix.shape()[1], spec_b.shape()[0]);
std::vector<float> max_a(spec_a.shape()[0]);
std::vector<float> min_a(spec_a.shape()[0]);
std::vector<float> max_b(spec_b.shape()[0]);
std::vector<float> min_b(spec_b.shape()[0]);
for (size_t spec_b_index = 0; spec_b_index < spec_b.shape()[0];
++spec_b_index) {
max_b[spec_b_index] = HWY_DYNAMIC_DISPATCH(HwyMax)(spec_b[{spec_b_index}]);
min_b[spec_b_index] = HWY_DYNAMIC_DISPATCH(HwyMin)(spec_b[{spec_b_index}]);
cost_matrix[{0}][spec_b_index] = std::numeric_limits<float>::infinity();
}
for (size_t spec_a_index = 1; spec_a_index < spec_a.shape()[0];
++spec_a_index) {
max_a[spec_a_index] = HWY_DYNAMIC_DISPATCH(HwyMax)(spec_a[{spec_a_index}]);
min_a[spec_a_index] = HWY_DYNAMIC_DISPATCH(HwyMin)(spec_a[{spec_a_index}]);
for (size_t spec_b_index = 0; spec_b_index < spec_b.shape()[0];
++spec_b_index) {
cost_matrix[{spec_a_index}][spec_b_index] =
Expand All @@ -145,11 +129,8 @@ std::vector<std::pair<size_t, size_t>> DTWSlice(
++spec_a_index) {
for (size_t spec_b_index = 1; spec_b_index < spec_b.shape()[0];
++spec_b_index) {
const float max_delta =
std::max(std::abs(max_a[spec_a_index] - min_b[spec_b_index]),
std::abs(max_b[spec_b_index] - min_a[spec_a_index]));
const float cost = HWY_DYNAMIC_DISPATCH(HwyDeltaNorm)(
spec_a[{spec_a_index}], spec_b[{spec_b_index}], 2.0, max_delta);
spec_a[{spec_a_index}], spec_b[{spec_b_index}]);
cost_matrix[{spec_a_index}][spec_b_index] =
cost +
std::min(cost_matrix[{spec_a_index - 1}][spec_b_index - 1],
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.

0 comments on commit da39b39

Please sign in to comment.