Skip to content

Commit 4f87724

Browse files
committed
simplify by using Array::from_fn
1 parent 4aefdf0 commit 4f87724

File tree

5 files changed

+14
-76
lines changed

5 files changed

+14
-76
lines changed

src/lib.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -976,17 +976,8 @@ pub trait SimdType<T, const N: usize> {
976976
#[must_use]
977977
fn from_array(array: [T; N]) -> Self;
978978

979-
/// Performs a binary operation corresponding elements
980-
/// on two SIMD types and returns the result.
981-
///
982-
/// This is useful for implementing
983-
/// operations that the compiler vectorizes but this library
984-
/// don't provide explicit support for.
985-
fn binary_op<FN: Fn(T, T) -> T>(self, rhs: Self, op: FN) -> Self;
986-
987-
/// performs a unary operation on each element of the SIMD type
988-
/// and returns the result.
989-
fn unary_op<FN: Fn(T) -> T>(self, op: FN) -> Self;
979+
/// provide same functionarlity as array from_fn
980+
fn from_fn<F: Fn(usize) -> T>(cb: F) -> Self;
990981
}
991982

992983
macro_rules! bulk_impl_const_rhs_op {

src/u32x4_.rs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -546,15 +546,7 @@ impl SimdType<u32, 4> for u32x4 {
546546
}
547547

548548
#[inline]
549-
fn binary_op<FN: Fn(u32, u32) -> u32>(self, rhs: Self, op: FN) -> Self {
550-
let a: [u32; 4] = cast(self);
551-
let b: [u32; 4] = cast(rhs);
552-
cast([op(a[0], b[0]), op(a[1], b[1]), op(a[2], b[2]), op(a[3], b[3])])
553-
}
554-
555-
#[inline]
556-
fn unary_op<FN: Fn(u32) -> u32>(self, op: FN) -> Self {
557-
let a: [u32; 4] = cast(self);
558-
cast([op(a[0]), op(a[1]), op(a[2]), op(a[3])])
549+
fn from_fn<F: Fn(usize) -> u32>(cb: F) -> Self {
550+
cast([cb(0), cb(1), cb(2), cb(3)])
559551
}
560552
}

src/u32x8_.rs

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -359,33 +359,7 @@ impl SimdType<u32, 8> for u32x8 {
359359
}
360360

361361
#[inline]
362-
fn binary_op<FN: Fn(u32, u32) -> u32>(self, rhs: Self, op: FN) -> Self {
363-
let a: [u32; 8] = cast(self);
364-
let b: [u32; 8] = cast(rhs);
365-
cast([
366-
op(a[0], b[0]),
367-
op(a[1], b[1]),
368-
op(a[2], b[2]),
369-
op(a[3], b[3]),
370-
op(a[4], b[4]),
371-
op(a[5], b[5]),
372-
op(a[6], b[6]),
373-
op(a[7], b[7]),
374-
])
375-
}
376-
377-
#[inline]
378-
fn unary_op<FN: Fn(u32) -> u32>(self, op: FN) -> Self {
379-
let a: [u32; 8] = cast(self);
380-
cast([
381-
op(a[0]),
382-
op(a[1]),
383-
op(a[2]),
384-
op(a[3]),
385-
op(a[4]),
386-
op(a[5]),
387-
op(a[6]),
388-
op(a[7]),
389-
])
362+
fn from_fn<F: Fn(usize) -> u32>(cb: F) -> Self {
363+
cast([cb(0), cb(1), cb(2), cb(3), cb(4), cb(5), cb(6), cb(7)])
390364
}
391365
}

tests/all_tests/t_common.rs

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,17 @@ pub fn test_binary_op<
1515
fn_scalar: FnScalar,
1616
fn_vector: FnVector,
1717
) {
18-
let mut expected = T::default();
19-
for i in 0..N {
20-
expected.as_mut_array()[i] = fn_scalar(a.as_array()[i], b.as_array()[i]);
21-
}
18+
let expected = T::from_fn(|i| fn_scalar(a.as_array()[i], b.as_array()[i]));
2219

2320
let actual = fn_vector(a, b);
2421

2522
// assert equality for manually calculated result
2623
assert_eq!(expected, actual, "scalar={:?} vector={:?}", expected, actual);
27-
28-
// assert equality using the binary_op method as well
29-
assert_eq!(
30-
expected,
31-
a.binary_op(b, fn_scalar),
32-
"scalar={:?} binary_op={:?}",
33-
expected,
34-
actual
35-
);
3624
}
3725

3826
pub fn test_unary_op<
39-
T: SimdType<V, N> + Default + PartialEq + std::fmt::Debug + Copy,
40-
V: Copy,
27+
T: SimdType<V, N> + PartialEq + std::fmt::Debug + Copy,
28+
V: Copy + PartialEq + std::fmt::Debug,
4129
FnVector: Fn(T) -> T,
4230
FnScalar: Fn(V) -> V,
4331
const N: usize,
@@ -46,22 +34,14 @@ pub fn test_unary_op<
4634
fn_scalar: FnScalar,
4735
fn_vector: FnVector,
4836
) {
49-
let mut expected = T::default();
37+
let expected = T::from_fn(|i| fn_scalar(a.as_array()[i]));
38+
// ensure that the elements got put in the right place
5039
for i in 0..N {
51-
expected.as_mut_array()[i] = fn_scalar(a.as_array()[i]);
40+
assert_eq!(expected.as_array()[i], fn_scalar(a.as_array()[i]));
5241
}
5342

5443
let actual = fn_vector(a);
5544

5645
// assert equality for manually calculated result
5746
assert_eq!(expected, actual, "scalar={:?} vector={:?}", expected, actual);
58-
59-
// assert equality using the unary_op method as well
60-
assert_eq!(
61-
expected,
62-
a.unary_op(fn_scalar),
63-
"scalar={:?} unary_op={:?}",
64-
expected,
65-
actual
66-
);
6747
}

tests/all_tests/t_usefulness.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,8 @@ fn branch_free_divide(numerator: u32x8, magic: u32x8, shift: u32x8) -> u32x8 {
394394
// Returns 32 high bits of the 64 bit result of multiplication of two u32s
395395
let mul_hi = |a, b| ((u64::from(a) * u64::from(b)) >> 32) as u32;
396396

397-
let q = numerator.binary_op(magic, mul_hi);
397+
let q =
398+
u32x8::from_fn(|i| mul_hi(numerator.as_array()[i], magic.as_array()[i]));
398399
let t = ((numerator - q) >> 1) + q;
399400
t >> shift
400401
}

0 commit comments

Comments
 (0)