From 21967e9514011c28fd6fc0164ad6f93117613778 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Mon, 10 Nov 2025 23:30:25 -0500 Subject: [PATCH 01/11] Improve test coverage a bit Catches a silly bug in the Intel simd_eq implementation. --- fearless_simd_tests/tests/harness/mod.rs | 1011 +++++++++++++++++++++- 1 file changed, 1005 insertions(+), 6 deletions(-) diff --git a/fearless_simd_tests/tests/harness/mod.rs b/fearless_simd_tests/tests/harness/mod.rs index 58e669a2..3e0df758 100644 --- a/fearless_simd_tests/tests/harness/mod.rs +++ b/fearless_simd_tests/tests/harness/mod.rs @@ -15,6 +15,38 @@ use fearless_simd::*; use fearless_simd_dev_macros::simd_test; +#[simd_test] +fn splat_f32x4(simd: S) { + let a = f32x4::splat(simd, 4.2); + assert_eq!(a.val, [4.2, 4.2, 4.2, 4.2]); +} + +#[simd_test] +fn abs_f32x4(simd: S) { + let a = f32x4::from_slice(simd, &[-1.0, 2.0, -3.0, 4.0]); + assert_eq!(a.abs().val, [1.0, 2.0, 3.0, 4.0]); +} + +#[simd_test] +fn neg_f32x4(simd: S) { + let a = f32x4::from_slice(simd, &[1.0, -2.0, 3.0, -4.0]); + assert_eq!((-a).val, [-1.0, 2.0, -3.0, 4.0]); +} + +#[simd_test] +fn add_f32x4(simd: S) { + let a = f32x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + let b = f32x4::from_slice(simd, &[5.0, 6.0, 7.0, 8.0]); + assert_eq!((a + b).val, [6.0, 8.0, 10.0, 12.0]); +} + +#[simd_test] +fn sub_f32x4(simd: S) { + let a = f32x4::from_slice(simd, &[10.0, 20.0, 30.0, 40.0]); + let b = f32x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + assert_eq!((a - b).val, [9.0, 18.0, 27.0, 36.0]); +} + #[simd_test] fn sqrt_f32x4(simd: S) { let a = f32x4::from_slice(simd, &[4.0, 0.0, 1.0, 2.0]); @@ -106,6 +138,14 @@ fn min_precise_f32x4(simd: S) { assert_eq!(a.min_precise(b).val, [1.0, -3.0, 0.0, 0.5]); } +#[simd_test] +fn msub_f32x4(simd: S) { + let a = f32x4::from_slice(simd, &[2.0, 3.0, 4.0, 5.0]); + let b = f32x4::from_slice(simd, &[10.0, 10.0, 10.0, 10.0]); + let c = f32x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + assert_eq!(a.msub(b, c).val, [19.0, 28.0, 37.0, 46.0]); +} + #[simd_test] fn max_precise_f32x4_with_nan(simd: S) { let a = f32x4::from_slice(simd, &[f32::NAN, -3.0, f32::INFINITY, 0.5]); @@ -254,6 +294,160 @@ fn not_i8x16(simd: S) { ); } +#[simd_test] +fn add_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, -1, -2, -3, -4, -5, -6, -7, -8], + ); + let b = i8x16::from_slice( + simd, + &[10, 20, 30, 40, 50, 60, 70, 80, 1, 2, 3, 4, 5, 6, 7, 8], + ); + assert_eq!( + (a + b).val, + [11, 22, 33, 44, 55, 66, 77, 88, 0, 0, 0, 0, 0, 0, 0, 0] + ); +} + +#[simd_test] +fn sub_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[10, 20, 30, 40, 50, 60, 70, 80, 0, 0, 0, 0, 0, 0, 0, 0], + ); + let b = i8x16::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!( + (a - b).val, + [ + 9, 18, 27, 36, 45, 54, 63, 72, -1, -2, -3, -4, -5, -6, -7, -8 + ] + ); +} + +#[simd_test] +fn neg_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, + ], + ); + assert_eq!( + (-a).val, + [ + -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, 16 + ] + ); +} + +#[simd_test] +fn simd_eq_i8x16(simd: S) { + let a = i8x16::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]); + let b = i8x16::from_slice(simd, &[1, 0, 3, 0, 5, 0, 7, 0, 1, 0, 3, 0, 5, 0, 7, 0]); + assert_eq!( + a.simd_eq(b).val, + [-1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0] + ); +} + +#[simd_test] +fn simd_lt_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[1, 2, 3, 4, -1, -2, 
-3, -4, 10, 20, 30, 40, 50, 60, 70, 80], + ); + let b = i8x16::from_slice( + simd, + &[2, 2, 2, 5, 0, 0, 0, 0, 5, 25, 25, 45, 45, 65, 65, 85], + ); + assert_eq!( + a.simd_lt(b).val, + [-1, 0, 0, -1, -1, -1, -1, -1, 0, -1, 0, -1, 0, -1, 0, -1] + ); +} + +#[simd_test] +fn simd_gt_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[2, 2, 2, 5, 0, 0, 0, 0, 5, 25, 25, 45, 45, 65, 65, 85], + ); + let b = i8x16::from_slice( + simd, + &[1, 2, 3, 4, -1, -2, -3, -4, 10, 20, 30, 40, 50, 60, 70, 80], + ); + assert_eq!( + a.simd_gt(b).val, + [-1, 0, 0, -1, -1, -1, -1, -1, 0, -1, 0, -1, 0, -1, 0, -1] + ); +} + +#[simd_test] +fn min_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, + ], + ); + let b = i8x16::from_slice( + simd, + &[ + 2, -1, 4, -3, 6, -5, 8, -7, 10, -9, 12, -11, 14, -13, 16, -15, + ], + ); + assert_eq!( + a.min(b).val, + [ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16 + ] + ); +} + +#[simd_test] +fn max_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, + ], + ); + let b = i8x16::from_slice( + simd, + &[ + 2, -1, 4, -3, 6, -5, 8, -7, 10, -9, 12, -11, 14, -13, 16, -15, + ], + ); + assert_eq!( + a.max(b).val, + [ + 2, -1, 4, -3, 6, -5, 8, -7, 10, -9, 12, -11, 14, -13, 16, -15 + ] + ); +} + +#[simd_test] +fn combine_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = i8x16::from_slice( + simd, + &[ + -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, + ], + ); + assert_eq!( + a.combine(b).val, + [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, -1, -2, -3, -4, -5, -6, -7, -8, + -9, -10, -11, -12, -13, -14, -15, -16 + ] + ); +} + #[simd_test] fn and_u8x16(simd: S) { let a = u8x16::from_slice(simd, &[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]); @@ -300,6 +494,111 @@ fn not_u8x16(simd: S) { ); } +#[simd_test] +fn add_u8x16(simd: S) { + let a = u8x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = u8x16::from_slice( + simd, + &[ + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, + ], + ); + assert_eq!( + (a + b).val, + [ + 11, 22, 33, 44, 55, 66, 77, 88, 99, 110, 121, 132, 143, 154, 165, 176 + ] + ); +} + +#[simd_test] +fn sub_u8x16(simd: S) { + let a = u8x16::from_slice( + simd, + &[ + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + ], + ); + let b = u8x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + assert_eq!( + (a - b).val, + [ + 99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85, 84 + ] + ); +} + +#[simd_test] +fn min_u8x16(simd: S) { + let a = u8x16::from_slice( + simd, + &[ + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, + ], + ); + let b = u8x16::from_slice( + simd, + &[ + 15, 15, 35, 35, 45, 65, 65, 85, 85, 105, 105, 125, 125, 145, 145, 165, + ], + ); + assert_eq!( + a.min(b).val, + [ + 10, 15, 30, 35, 45, 60, 65, 80, 85, 100, 105, 120, 125, 140, 145, 160 + ] + ); +} + +#[simd_test] +fn max_u8x16(simd: S) { + let a = u8x16::from_slice( + simd, + &[ + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, + ], + ); + let b = u8x16::from_slice( + simd, + &[ + 15, 15, 35, 35, 45, 65, 65, 85, 85, 105, 105, 125, 125, 145, 145, 165, + ], + ); + assert_eq!( + a.max(b).val, + [ + 15, 20, 35, 40, 50, 65, 70, 85, 90, 105, 110, 
125, 130, 145, 150, 165 + ] + ); +} + +#[simd_test] +fn combine_u8x16(simd: S) { + let a = u8x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = u8x16::from_slice( + simd, + &[ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + assert_eq!( + a.combine(b).val, + [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32 + ] + ); +} + #[simd_test] fn and_mask8x16(simd: S) { let a = mask8x16::from_slice(simd, &[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]); @@ -605,6 +904,56 @@ fn zip_high_i8x16(simd: S) { ); } +#[simd_test] +fn zip_low_i8x32(simd: S) { + let a = i8x32::from_slice( + simd, + &[ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17, -18, 19, -20, 21, + -22, 23, -24, 25, -26, 27, -28, 29, -30, 31, -32, + ], + ); + let b = i8x32::from_slice( + simd, + &[ + 33, -34, 35, -36, 37, -38, 39, -40, 41, -42, 43, -44, 45, -46, 47, -48, 49, -50, 51, + -52, 53, -54, 55, -56, 57, -58, 59, -60, 61, -62, 63, -64, + ], + ); + assert_eq!( + simd.zip_low_i8x32(a, b).val, + [ + 1, 33, -2, -34, 3, 35, -4, -36, 5, 37, -6, -38, 7, 39, -8, -40, 9, 41, -10, -42, 11, + 43, -12, -44, 13, 45, -14, -46, 15, 47, -16, -48 + ] + ); +} + +#[simd_test] +fn zip_high_i8x32(simd: S) { + let a = i8x32::from_slice( + simd, + &[ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, 17, -18, 19, -20, 21, + -22, 23, -24, 25, -26, 27, -28, 29, -30, 31, -32, + ], + ); + let b = i8x32::from_slice( + simd, + &[ + 33, -34, 35, -36, 37, -38, 39, -40, 41, -42, 43, -44, 45, -46, 47, -48, 49, -50, 51, + -52, 53, -54, 55, -56, 57, -58, 59, -60, 61, -62, 63, -64, + ], + ); + assert_eq!( + simd.zip_high_i8x32(a, b).val, + [ + 17, 49, -18, -50, 19, 51, -20, -52, 21, 53, -22, -54, 23, 55, -24, -56, 25, 57, -26, + -58, 27, 59, -28, -60, 29, 61, -30, -62, 31, 63, -32, -64 + ] + ); +} + #[simd_test] fn zip_low_u8x16(simd: S) { let a = u8x16::from_slice( @@ -642,12 +991,62 @@ fn zip_high_u8x16(simd: S) { } #[simd_test] -fn zip_low_i16x8(simd: S) { - let a = i16x8::from_slice(simd, &[1, -2, 3, -4, 5, -6, 7, -8]); - let b = i16x8::from_slice(simd, &[9, -10, 11, -12, 13, -14, 15, -16]); - assert_eq!( - simd.zip_low_i16x8(a, b).val, - [1, 9, -2, -10, 3, 11, -4, -12] +fn zip_low_u8x32(simd: S) { + let a = u8x32::from_slice( + simd, + &[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + let b = u8x32::from_slice( + simd, + &[ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + ], + ); + assert_eq!( + simd.zip_low_u8x32(a, b).val, + [ + 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, 8, 40, 9, 41, 10, 42, 11, 43, + 12, 44, 13, 45, 14, 46, 15, 47 + ] + ); +} + +#[simd_test] +fn zip_high_u8x32(simd: S) { + let a = u8x32::from_slice( + simd, + &[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + let b = u8x32::from_slice( + simd, + &[ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + ], + ); + assert_eq!( + simd.zip_high_u8x32(a, b).val, + [ + 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, 24, 56, 25, 57, 26, 58, + 27, 59, 28, 60, 29, 61, 30, 62, 31, 63 + ] + ); +} + +#[simd_test] +fn zip_low_i16x8(simd: S) { + let a = 
i16x8::from_slice(simd, &[1, -2, 3, -4, 5, -6, 7, -8]); + let b = i16x8::from_slice(simd, &[9, -10, 11, -12, 13, -14, 15, -16]); + assert_eq!( + simd.zip_low_i16x8(a, b).val, + [1, 9, -2, -10, 3, 11, -4, -12] ); } @@ -661,6 +1060,50 @@ fn zip_high_i16x8(simd: S) { ); } +#[simd_test] +fn zip_low_i16x16(simd: S) { + let a = i16x16::from_slice( + simd, + &[ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, + ], + ); + let b = i16x16::from_slice( + simd, + &[ + 17, -18, 19, -20, 21, -22, 23, -24, 25, -26, 27, -28, 29, -30, 31, -32, + ], + ); + assert_eq!( + simd.zip_low_i16x16(a, b).val, + [ + 1, 17, -2, -18, 3, 19, -4, -20, 5, 21, -6, -22, 7, 23, -8, -24 + ] + ); +} + +#[simd_test] +fn zip_high_i16x16(simd: S) { + let a = i16x16::from_slice( + simd, + &[ + 1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16, + ], + ); + let b = i16x16::from_slice( + simd, + &[ + 17, -18, 19, -20, 21, -22, 23, -24, 25, -26, 27, -28, 29, -30, 31, -32, + ], + ); + assert_eq!( + simd.zip_high_i16x16(a, b).val, + [ + 9, 25, -10, -26, 11, 27, -12, -28, 13, 29, -14, -30, 15, 31, -16, -32 + ] + ); +} + #[simd_test] fn zip_low_u16x8(simd: S) { let a = u16x8::from_slice(simd, &[0, 1, 2, 3, 4, 5, 6, 7]); @@ -675,6 +1118,42 @@ fn zip_high_u16x8(simd: S) { assert_eq!(simd.zip_high_u16x8(a, b).val, [4, 12, 5, 13, 6, 14, 7, 15]); } +#[simd_test] +fn zip_low_u16x16(simd: S) { + let a = u16x16::from_slice( + simd, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + ); + let b = u16x16::from_slice( + simd, + &[ + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + assert_eq!( + simd.zip_low_u16x16(a, b).val, + [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23] + ); +} + +#[simd_test] +fn zip_high_u16x16(simd: S) { + let a = u16x16::from_slice( + simd, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + ); + let b = u16x16::from_slice( + simd, + &[ + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + assert_eq!( + simd.zip_high_u16x16(a, b).val, + [8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31] + ); +} + #[simd_test] fn zip_low_i32x4(simd: S) { let a = i32x4::from_slice(simd, &[1, -2, 3, -4]); @@ -689,6 +1168,26 @@ fn zip_high_i32x4(simd: S) { assert_eq!(simd.zip_high_i32x4(a, b).val, [3, 7, -4, -8]); } +#[simd_test] +fn zip_low_i32x8(simd: S) { + let a = i32x8::from_slice(simd, &[1, -2, 3, -4, 5, -6, 7, -8]); + let b = i32x8::from_slice(simd, &[9, -10, 11, -12, 13, -14, 15, -16]); + assert_eq!( + simd.zip_low_i32x8(a, b).val, + [1, 9, -2, -10, 3, 11, -4, -12] + ); +} + +#[simd_test] +fn zip_high_i32x8(simd: S) { + let a = i32x8::from_slice(simd, &[1, -2, 3, -4, 5, -6, 7, -8]); + let b = i32x8::from_slice(simd, &[9, -10, 11, -12, 13, -14, 15, -16]); + assert_eq!( + simd.zip_high_i32x8(a, b).val, + [5, 13, -6, -14, 7, 15, -8, -16] + ); +} + #[simd_test] fn zip_low_u32x4(simd: S) { let a = u32x4::from_slice(simd, &[0, 1, 2, 3]); @@ -703,6 +1202,34 @@ fn zip_high_u32x4(simd: S) { assert_eq!(simd.zip_high_u32x4(a, b).val, [2, 6, 3, 7]); } +#[simd_test] +fn zip_low_u32x8(simd: S) { + let a = u32x8::from_slice(simd, &[0, 1, 2, 3, 4, 5, 6, 7]); + let b = u32x8::from_slice(simd, &[8, 9, 10, 11, 12, 13, 14, 15]); + assert_eq!(simd.zip_low_u32x8(a, b).val, [0, 8, 1, 9, 2, 10, 3, 11]); +} + +#[simd_test] +fn zip_high_u32x8(simd: S) { + let a = u32x8::from_slice(simd, &[0, 1, 2, 3, 4, 5, 6, 7]); + let b = u32x8::from_slice(simd, &[8, 9, 10, 11, 12, 13, 14, 15]); + assert_eq!(simd.zip_high_u32x8(a, b).val, [4, 12, 5, 13, 6, 14, 
7, 15]); +} + +#[simd_test] +fn zip_low_f64x4(simd: S) { + let a = f64x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + let b = f64x4::from_slice(simd, &[5.0, 6.0, 7.0, 8.0]); + assert_eq!(simd.zip_low_f64x4(a, b).val, [1.0, 5.0, 2.0, 6.0]); +} + +#[simd_test] +fn zip_high_f64x4(simd: S) { + let a = f64x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + let b = f64x4::from_slice(simd, &[5.0, 6.0, 7.0, 8.0]); + assert_eq!(simd.zip_high_f64x4(a, b).val, [3.0, 7.0, 4.0, 8.0]); +} + #[simd_test] fn unzip_low_f32x4(simd: S) { let a = f32x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); @@ -773,6 +1300,56 @@ fn unzip_high_i8x16(simd: S) { ); } +#[simd_test] +fn unzip_low_i8x32(simd: S) { + let a = i8x32::from_slice( + simd, + &[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + let b = i8x32::from_slice( + simd, + &[ + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + ], + ); + assert_eq!( + simd.unzip_low_i8x32(a, b).val, + [ + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, + 47, 49, 51, 53, 55, 57, 59, 61, 63 + ] + ); +} + +#[simd_test] +fn unzip_high_i8x32(simd: S) { + let a = i8x32::from_slice( + simd, + &[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + let b = i8x32::from_slice( + simd, + &[ + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + ], + ); + assert_eq!( + simd.unzip_high_i8x32(a, b).val, + [ + 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, + 48, 50, 52, 54, 56, 58, 60, 62, 64 + ] + ); +} + #[simd_test] fn unzip_low_u8x16(simd: S) { let a = u8x16::from_slice( @@ -809,6 +1386,56 @@ fn unzip_high_u8x16(simd: S) { ); } +#[simd_test] +fn unzip_low_u8x32(simd: S) { + let a = u8x32::from_slice( + simd, + &[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + let b = u8x32::from_slice( + simd, + &[ + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + ], + ); + assert_eq!( + simd.unzip_low_u8x32(a, b).val, + [ + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, + 47, 49, 51, 53, 55, 57, 59, 61, 63 + ] + ); +} + +#[simd_test] +fn unzip_high_u8x32(simd: S) { + let a = u8x32::from_slice( + simd, + &[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + let b = u8x32::from_slice( + simd, + &[ + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + ], + ); + assert_eq!( + simd.unzip_high_u8x32(a, b).val, + [ + 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, + 48, 50, 52, 54, 56, 58, 60, 62, 64 + ] + ); +} + #[simd_test] fn unzip_low_i16x8(simd: S) { let a = i16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); @@ -843,6 +1470,78 @@ fn unzip_high_u16x8(simd: S) { ); } +#[simd_test] +fn unzip_low_i16x16(simd: S) { + let a = i16x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = i16x16::from_slice( + simd, + &[ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 
29, 30, 31, 32, + ], + ); + assert_eq!( + simd.unzip_low_i16x16(a, b).val, + [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31] + ); +} + +#[simd_test] +fn unzip_high_i16x16(simd: S) { + let a = i16x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = i16x16::from_slice( + simd, + &[ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + assert_eq!( + simd.unzip_high_i16x16(a, b).val, + [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32] + ); +} + +#[simd_test] +fn unzip_low_u16x16(simd: S) { + let a = u16x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = u16x16::from_slice( + simd, + &[ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + assert_eq!( + simd.unzip_low_u16x16(a, b).val, + [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31] + ); +} + +#[simd_test] +fn unzip_high_u16x16(simd: S) { + let a = u16x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = u16x16::from_slice( + simd, + &[ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + assert_eq!( + simd.unzip_high_u16x16(a, b).val, + [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32] + ); +} + #[simd_test] fn unzip_low_i32x4(simd: S) { let a = i32x4::from_slice(simd, &[1, 2, 3, 4]); @@ -871,6 +1570,40 @@ fn unzip_high_u32x4(simd: S) { assert_eq!(simd.unzip_high_u32x4(a, b).val, [2, 4, 6, 8]); } +#[simd_test] +fn unzip_low_i32x8(simd: S) { + let a = i32x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = i32x8::from_slice(simd, &[9, 10, 11, 12, 13, 14, 15, 16]); + assert_eq!(simd.unzip_low_i32x8(a, b).val, [1, 3, 5, 7, 9, 11, 13, 15]); +} + +#[simd_test] +fn unzip_high_i32x8(simd: S) { + let a = i32x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = i32x8::from_slice(simd, &[9, 10, 11, 12, 13, 14, 15, 16]); + assert_eq!( + simd.unzip_high_i32x8(a, b).val, + [2, 4, 6, 8, 10, 12, 14, 16] + ); +} + +#[simd_test] +fn unzip_low_u32x8(simd: S) { + let a = u32x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = u32x8::from_slice(simd, &[9, 10, 11, 12, 13, 14, 15, 16]); + assert_eq!(simd.unzip_low_u32x8(a, b).val, [1, 3, 5, 7, 9, 11, 13, 15]); +} + +#[simd_test] +fn unzip_high_u32x8(simd: S) { + let a = u32x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = u32x8::from_slice(simd, &[9, 10, 11, 12, 13, 14, 15, 16]); + assert_eq!( + simd.unzip_high_u32x8(a, b).val, + [2, 4, 6, 8, 10, 12, 14, 16] + ); +} + #[simd_test] fn unzip_low_f64x2(simd: S) { let a = f64x2::from_slice(simd, &[1.0, 2.0]); @@ -885,6 +1618,20 @@ fn unzip_high_f64x2(simd: S) { assert_eq!(simd.unzip_high_f64x2(a, b).val, [2.0, 4.0]); } +#[simd_test] +fn unzip_low_f64x4(simd: S) { + let a = f64x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + let b = f64x4::from_slice(simd, &[5.0, 6.0, 7.0, 8.0]); + assert_eq!(simd.unzip_low_f64x4(a, b).val, [1.0, 3.0, 5.0, 7.0]); +} + +#[simd_test] +fn unzip_high_f64x4(simd: S) { + let a = f64x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + let b = f64x4::from_slice(simd, &[5.0, 6.0, 7.0, 8.0]); + assert_eq!(simd.unzip_high_f64x4(a, b).val, [2.0, 4.0, 6.0, 8.0]); +} + #[simd_test] fn shr_i8x16(simd: S) { let a = i8x16::from_slice( @@ -969,6 +1716,242 @@ fn shl_u32x4(simd: S) { assert_eq!((a << 4).val, [0xFFFFFFF0, 0xFFFF0, 0xFF0, 0]); } +#[simd_test] +fn add_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = i16x8::from_slice(simd, &[10, 
20, 30, 40, 50, 60, 70, 80]); + assert_eq!((a + b).val, [11, 22, 33, 44, 55, 66, 77, 88]); +} + +#[simd_test] +fn sub_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[100, 200, 300, 400, 500, 600, 700, 800]); + let b = i16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!((a - b).val, [99, 198, 297, 396, 495, 594, 693, 792]); +} + +#[simd_test] +fn neg_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[1, -2, 3, -4, 5, -6, 7, -8]); + assert_eq!((-a).val, [-1, 2, -3, 4, -5, 6, -7, 8]); +} + +#[simd_test] +fn simd_eq_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = i16x8::from_slice(simd, &[1, 0, 3, 0, 5, 0, 7, 0]); + assert_eq!(a.simd_eq(b).val, [-1, 0, -1, 0, -1, 0, -1, 0]); +} + +#[simd_test] +fn simd_lt_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[1, 2, 3, 4, -1, -2, -3, -4]); + let b = i16x8::from_slice(simd, &[2, 2, 2, 5, 0, 0, 0, 0]); + assert_eq!(a.simd_lt(b).val, [-1, 0, 0, -1, -1, -1, -1, -1]); +} + +#[simd_test] +fn simd_gt_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[2, 2, 2, 5, 0, 0, 0, 0]); + let b = i16x8::from_slice(simd, &[1, 2, 3, 4, -1, -2, -3, -4]); + assert_eq!(a.simd_gt(b).val, [-1, 0, 0, -1, -1, -1, -1, -1]); +} + +#[simd_test] +fn min_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[1, -2, 3, -4, 5, -6, 7, -8]); + let b = i16x8::from_slice(simd, &[2, -1, 4, -3, 6, -5, 8, -7]); + assert_eq!(a.min(b).val, [1, -2, 3, -4, 5, -6, 7, -8]); +} + +#[simd_test] +fn max_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[1, -2, 3, -4, 5, -6, 7, -8]); + let b = i16x8::from_slice(simd, &[2, -1, 4, -3, 6, -5, 8, -7]); + assert_eq!(a.max(b).val, [2, -1, 4, -3, 6, -5, 8, -7]); +} + +#[simd_test] +fn combine_i16x8(simd: S) { + let a = i16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = i16x8::from_slice(simd, &[9, 10, 11, 12, 13, 14, 15, 16]); + assert_eq!( + a.combine(b).val, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + ); +} + +#[simd_test] +fn add_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = u16x8::from_slice(simd, &[10, 20, 30, 40, 50, 60, 70, 80]); + assert_eq!((a + b).val, [11, 22, 33, 44, 55, 66, 77, 88]); +} + +#[simd_test] +fn sub_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[100, 200, 300, 400, 500, 600, 700, 800]); + let b = u16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!((a - b).val, [99, 198, 297, 396, 495, 594, 693, 792]); +} + +#[simd_test] +fn simd_eq_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[1, 2, 32768, 40000, 65535, 6, 7, 8]); + let b = u16x8::from_slice(simd, &[1, 0, 32768, 0, 65535, 0, 7, 0]); + assert_eq!(a.simd_eq(b).val, [-1, 0, -1, 0, -1, 0, -1, 0]); +} + +#[simd_test] +fn simd_lt_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[1, 2, 3, 4, 100, 200, 300, 400]); + let b = u16x8::from_slice(simd, &[2, 2, 2, 5, 40000, 150, 50000, 350]); + assert_eq!(a.simd_lt(b).val, [-1, 0, 0, -1, -1, 0, -1, 0]); +} + +#[simd_test] +fn simd_gt_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[2, 2, 2, 5, 40000, 150, 50000, 350]); + let b = u16x8::from_slice(simd, &[1, 2, 3, 4, 100, 200, 300, 400]); + assert_eq!(a.simd_gt(b).val, [-1, 0, 0, -1, -1, 0, -1, 0]); +} + +#[simd_test] +fn min_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[10, 20, 30, 40, 50, 60, 70, 80]); + let b = u16x8::from_slice(simd, &[15, 15, 35, 35, 45, 65, 65, 85]); + assert_eq!(a.min(b).val, [10, 15, 30, 35, 45, 60, 65, 80]); +} + +#[simd_test] +fn max_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[10, 20, 30, 
40, 50, 60, 70, 80]); + let b = u16x8::from_slice(simd, &[15, 15, 35, 35, 45, 65, 65, 85]); + assert_eq!(a.max(b).val, [15, 20, 35, 40, 50, 65, 70, 85]); +} + +#[simd_test] +fn combine_u16x8(simd: S) { + let a = u16x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = u16x8::from_slice(simd, &[9, 10, 11, 12, 13, 14, 15, 16]); + assert_eq!( + a.combine(b).val, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + ); +} + +#[simd_test] +fn add_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[1, 2, 3, 4]); + let b = i32x4::from_slice(simd, &[10, 20, 30, 40]); + assert_eq!((a + b).val, [11, 22, 33, 44]); +} + +#[simd_test] +fn sub_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[100, 200, 300, 400]); + let b = i32x4::from_slice(simd, &[1, 2, 3, 4]); + assert_eq!((a - b).val, [99, 198, 297, 396]); +} + +#[simd_test] +fn simd_eq_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[1, 2, 3, 4]); + let b = i32x4::from_slice(simd, &[1, 0, 3, 0]); + assert_eq!(a.simd_eq(b).val, [-1, 0, -1, 0]); +} + +#[simd_test] +fn simd_lt_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[1, 2, -3, -4]); + let b = i32x4::from_slice(simd, &[2, 2, 0, 0]); + assert_eq!(a.simd_lt(b).val, [-1, 0, -1, -1]); +} + +#[simd_test] +fn simd_gt_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[2, 2, 0, 0]); + let b = i32x4::from_slice(simd, &[1, 2, -3, -4]); + assert_eq!(a.simd_gt(b).val, [-1, 0, -1, -1]); +} + +#[simd_test] +fn min_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[1, -2, 3, -4]); + let b = i32x4::from_slice(simd, &[2, -1, 4, -3]); + assert_eq!(a.min(b).val, [1, -2, 3, -4]); +} + +#[simd_test] +fn max_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[1, -2, 3, -4]); + let b = i32x4::from_slice(simd, &[2, -1, 4, -3]); + assert_eq!(a.max(b).val, [2, -1, 4, -3]); +} + +#[simd_test] +fn combine_i32x4(simd: S) { + let a = i32x4::from_slice(simd, &[1, 2, 3, 4]); + let b = i32x4::from_slice(simd, &[5, 6, 7, 8]); + assert_eq!(a.combine(b).val, [1, 2, 3, 4, 5, 6, 7, 8]); +} + +#[simd_test] +fn add_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[1, 2, 3, 4]); + let b = u32x4::from_slice(simd, &[10, 20, 30, 40]); + assert_eq!((a + b).val, [11, 22, 33, 44]); +} + +#[simd_test] +fn sub_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[100, 200, 300, 400]); + let b = u32x4::from_slice(simd, &[1, 2, 3, 4]); + assert_eq!((a - b).val, [99, 198, 297, 396]); +} + +#[simd_test] +fn simd_eq_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[1, 2, 2147483648, 4294967295]); + let b = u32x4::from_slice(simd, &[1, 0, 2147483648, 0]); + assert_eq!(a.simd_eq(b).val, [-1, 0, -1, 0]); +} + +#[simd_test] +fn simd_lt_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[1, 2, 100, 200]); + let b = u32x4::from_slice(simd, &[2, 2, 3000000000, 150]); + assert_eq!(a.simd_lt(b).val, [-1, 0, -1, 0]); +} + +#[simd_test] +fn simd_gt_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[2, 2, 3000000000, 150]); + let b = u32x4::from_slice(simd, &[1, 2, 100, 200]); + assert_eq!(a.simd_gt(b).val, [-1, 0, -1, 0]); +} + +#[simd_test] +fn min_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[10, 20, 30, 40]); + let b = u32x4::from_slice(simd, &[15, 15, 35, 35]); + assert_eq!(a.min(b).val, [10, 15, 30, 35]); +} + +#[simd_test] +fn max_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[10, 20, 30, 40]); + let b = u32x4::from_slice(simd, &[15, 15, 35, 35]); + assert_eq!(a.max(b).val, [15, 20, 35, 40]); +} + +#[simd_test] +fn combine_u32x4(simd: S) { + let a = u32x4::from_slice(simd, &[1, 2, 3, 4]); + let b = 
u32x4::from_slice(simd, &[5, 6, 7, 8]); + assert_eq!(a.combine(b).val, [1, 2, 3, 4, 5, 6, 7, 8]); +} + #[simd_test] fn select_f32x4(simd: S) { let mask = mask32x4::from_slice(simd, &[-1, 0, -1, 0]); @@ -1228,6 +2211,22 @@ fn cvt_i32_f32x4(simd: S) { assert_eq!(a.cvt_i32().val, [-10, 0, 13, 234234]); } +#[simd_test] +fn simd_eq_u8x16(simd: S) { + let a = u8x16::from_slice( + simd, + &[1, 2, 128, 200, 255, 6, 7, 8, 1, 2, 128, 200, 255, 6, 7, 8], + ); + let b = u8x16::from_slice( + simd, + &[1, 0, 128, 0, 255, 0, 7, 0, 1, 0, 128, 0, 255, 0, 7, 0], + ); + assert_eq!( + a.simd_eq(b).val, + [-1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0] + ); +} + #[simd_test] fn simd_ge_u8x16(simd: S) { let vals = u8x16::from_slice( From 210a712e480b1a698b6c92ec7a76baa401836b63 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 18:13:52 -0500 Subject: [PATCH 02/11] Implement AVX2 and improve x86 codegen --- fearless_simd/src/generated/avx2.rs | 1539 +++++++++------------- fearless_simd/src/generated/sse4_2.rs | 149 +-- fearless_simd_gen/src/arch/x86_common.rs | 8 +- fearless_simd_gen/src/mk_avx2.rs | 61 +- fearless_simd_gen/src/mk_sse4_2.rs | 403 ++++-- fearless_simd_gen/src/x86_common.rs | 35 +- 6 files changed, 1047 insertions(+), 1148 deletions(-) diff --git a/fearless_simd/src/generated/avx2.rs b/fearless_simd/src/generated/avx2.rs index 6d98d021..6ed464cd 100644 --- a/fearless_simd/src/generated/avx2.rs +++ b/fearless_simd/src/generated/avx2.rs @@ -104,23 +104,23 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_eq_f32x4(self, a: f32x4, b: f32x4) -> mask32x4 { - unsafe { _mm_castps_si128(_mm_cmpeq_ps(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castps_si128(_mm_cmp_ps::<0i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_lt_f32x4(self, a: f32x4, b: f32x4) -> mask32x4 { - unsafe { _mm_castps_si128(_mm_cmplt_ps(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castps_si128(_mm_cmp_ps::<17i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_le_f32x4(self, a: f32x4, b: f32x4) -> mask32x4 { - unsafe { _mm_castps_si128(_mm_cmple_ps(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castps_si128(_mm_cmp_ps::<18i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_ge_f32x4(self, a: f32x4, b: f32x4) -> mask32x4 { - unsafe { _mm_castps_si128(_mm_cmpge_ps(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castps_si128(_mm_cmp_ps::<29i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_gt_f32x4(self, a: f32x4, b: f32x4) -> mask32x4 { - unsafe { _mm_castps_si128(_mm_cmpgt_ps(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castps_si128(_mm_cmp_ps::<30i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn zip_low_f32x4(self, a: f32x4, b: f32x4) -> f32x4 { @@ -176,10 +176,7 @@ impl Simd for Avx2 { } #[inline(always)] fn select_f32x4(self, a: mask32x4, b: f32x4, c: f32x4) -> f32x4 { - unsafe { - let mask = _mm_castsi128_ps(a.into()); - _mm_or_ps(_mm_and_ps(mask, b.into()), _mm_andnot_ps(mask, c.into())).simd_into(self) - } + unsafe { _mm_blendv_ps(c.into(), b.into(), _mm_castsi128_ps(a.into())).simd_into(self) } } #[inline(always)] fn combine_f32x4(self, a: f32x4, b: f32x4) -> f32x8 { @@ -263,8 +260,8 @@ impl Simd for Avx2 { unsafe { let val = a.into(); let shift_count = _mm_cvtsi32_si128(shift as i32); - let lo_16 = _mm_unpacklo_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); - let hi_16 = _mm_unpackhi_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); + let lo_16 = 
_mm_unpacklo_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); + let hi_16 = _mm_unpackhi_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); let lo_shifted = _mm_sra_epi16(lo_16, shift_count); let hi_shifted = _mm_sra_epi16(hi_16, shift_count); _mm_packs_epi16(lo_shifted, hi_shifted).simd_into(self) @@ -279,8 +276,8 @@ impl Simd for Avx2 { unsafe { let val = a.into(); let shift_count = _mm_cvtsi32_si128(shift as i32); - let lo_16 = _mm_unpacklo_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); - let hi_16 = _mm_unpackhi_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); + let lo_16 = _mm_unpacklo_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); + let hi_16 = _mm_unpackhi_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); let lo_shifted = _mm_sll_epi16(lo_16, shift_count); let hi_shifted = _mm_sll_epi16(hi_16, shift_count); _mm_packs_epi16(lo_shifted, hi_shifted).simd_into(self) @@ -292,7 +289,7 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_lt_i8x16(self, a: i8x16, b: i8x16) -> mask8x16 { - unsafe { _mm_cmplt_epi8(a.into(), b.into()).simd_into(self) } + unsafe { _mm_cmpgt_epi8(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i8x16(self, a: i8x16, b: i8x16) -> mask8x16 { @@ -317,7 +314,7 @@ impl Simd for Avx2 { #[inline(always)] fn unzip_low_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { unsafe { - let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -326,21 +323,15 @@ impl Simd for Avx2 { #[inline(always)] fn unzip_high_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { unsafe { - let mask = _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_i8x16(self, a: mask8x16, b: i8x16, c: i8x16) -> i8x16 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { @@ -437,12 +428,7 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_eq_u8x16(self, a: u8x16, b: u8x16) -> mask8x16 { - unsafe { - let sign_bit = _mm_set1_epi8(0x80u8 as _); - let a_signed = _mm_xor_si128(a.into(), sign_bit); - let b_signed = _mm_xor_si128(b.into(), sign_bit); - _mm_cmpgt_epi8(a_signed, b_signed).simd_into(self) - } + unsafe { _mm_cmpeq_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u8x16(self, a: u8x16, b: u8x16) -> mask8x16 { @@ -481,7 +467,7 @@ impl Simd for Avx2 { #[inline(always)] fn unzip_low_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { unsafe { - let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -490,21 +476,15 @@ impl Simd for Avx2 { #[inline(always)] fn unzip_high_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { unsafe { - let mask = _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, 
1, 3, 5, 7, 9, 11, 13, 15); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_u8x16(self, a: mask8x16, b: u8x16, c: u8x16) -> u8x16 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { @@ -564,13 +544,7 @@ impl Simd for Avx2 { b: mask8x16, c: mask8x16, ) -> mask8x16 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask8x16(self, a: mask8x16, b: mask8x16) -> mask8x16 { @@ -633,7 +607,7 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_lt_i16x8(self, a: i16x8, b: i16x8) -> mask16x8 { - unsafe { _mm_cmplt_epi16(a.into(), b.into()).simd_into(self) } + unsafe { _mm_cmpgt_epi16(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i16x8(self, a: i16x8, b: i16x8) -> mask16x8 { @@ -658,7 +632,7 @@ impl Simd for Avx2 { #[inline(always)] fn unzip_low_i16x8(self, a: i16x8, b: i16x8) -> i16x8 { unsafe { - let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -667,21 +641,15 @@ impl Simd for Avx2 { #[inline(always)] fn unzip_high_i16x8(self, a: i16x8, b: i16x8) -> i16x8 { unsafe { - let mask = _mm_setr_epi8(2, 3, 6, 7, 10, 11, 14, 15, 2, 3, 6, 7, 10, 11, 14, 15); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_i16x8(self, a: mask16x8, b: i16x8, c: i16x8) -> i16x8 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i16x8(self, a: i16x8, b: i16x8) -> i16x8 { @@ -762,12 +730,7 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_eq_u16x8(self, a: u16x8, b: u16x8) -> mask16x8 { - unsafe { - let sign_bit = _mm_set1_epi16(0x8000u16 as _); - let a_signed = _mm_xor_si128(a.into(), sign_bit); - let b_signed = _mm_xor_si128(b.into(), sign_bit); - _mm_cmpgt_epi16(a_signed, b_signed).simd_into(self) - } + unsafe { _mm_cmpeq_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u16x8(self, a: u16x8, b: u16x8) -> mask16x8 { @@ -806,7 +769,7 @@ impl Simd for Avx2 { #[inline(always)] fn unzip_low_u16x8(self, a: u16x8, b: u16x8) -> u16x8 { unsafe { - let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -815,21 +778,15 @@ impl Simd for Avx2 { 
#[inline(always)] fn unzip_high_u16x8(self, a: u16x8, b: u16x8) -> u16x8 { unsafe { - let mask = _mm_setr_epi8(2, 3, 6, 7, 10, 11, 14, 15, 2, 3, 6, 7, 10, 11, 14, 15); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_u16x8(self, a: mask16x8, b: u16x8, c: u16x8) -> u16x8 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u16x8(self, a: u16x8, b: u16x8) -> u16x8 { @@ -887,13 +844,7 @@ impl Simd for Avx2 { b: mask16x8, c: mask16x8, ) -> mask16x8 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask16x8(self, a: mask16x8, b: mask16x8) -> mask16x8 { @@ -956,7 +907,7 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_lt_i32x4(self, a: i32x4, b: i32x4) -> mask32x4 { - unsafe { _mm_cmplt_epi32(a.into(), b.into()).simd_into(self) } + unsafe { _mm_cmpgt_epi32(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i32x4(self, a: i32x4, b: i32x4) -> mask32x4 { @@ -996,13 +947,7 @@ impl Simd for Avx2 { } #[inline(always)] fn select_i32x4(self, a: mask32x4, b: i32x4, c: i32x4) -> i32x4 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i32x4(self, a: i32x4, b: i32x4) -> i32x4 { @@ -1087,12 +1032,7 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_eq_u32x4(self, a: u32x4, b: u32x4) -> mask32x4 { - unsafe { - let sign_bit = _mm_set1_epi32(0x80000000u32 as _); - let a_signed = _mm_xor_si128(a.into(), sign_bit); - let b_signed = _mm_xor_si128(b.into(), sign_bit); - _mm_cmpgt_epi32(a_signed, b_signed).simd_into(self) - } + unsafe { _mm_cmpeq_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u32x4(self, a: u32x4, b: u32x4) -> mask32x4 { @@ -1146,13 +1086,7 @@ impl Simd for Avx2 { } #[inline(always)] fn select_u32x4(self, a: mask32x4, b: u32x4, c: u32x4) -> u32x4 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u32x4(self, a: u32x4, b: u32x4) -> u32x4 { @@ -1207,13 +1141,7 @@ impl Simd for Avx2 { b: mask32x4, c: mask32x4, ) -> mask32x4 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask32x4(self, a: mask32x4, b: mask32x4) -> mask32x4 { @@ -1267,23 +1195,23 @@ impl Simd for Avx2 { } #[inline(always)] fn simd_eq_f64x2(self, a: f64x2, b: f64x2) -> mask64x2 { - unsafe { _mm_castpd_si128(_mm_cmpeq_pd(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castpd_si128(_mm_cmp_pd::<0i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_lt_f64x2(self, a: f64x2, b: f64x2) -> mask64x2 { - unsafe { 
_mm_castpd_si128(_mm_cmplt_pd(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castpd_si128(_mm_cmp_pd::<17i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_le_f64x2(self, a: f64x2, b: f64x2) -> mask64x2 { - unsafe { _mm_castpd_si128(_mm_cmple_pd(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castpd_si128(_mm_cmp_pd::<18i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_ge_f64x2(self, a: f64x2, b: f64x2) -> mask64x2 { - unsafe { _mm_castpd_si128(_mm_cmpge_pd(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castpd_si128(_mm_cmp_pd::<29i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_gt_f64x2(self, a: f64x2, b: f64x2) -> mask64x2 { - unsafe { _mm_castpd_si128(_mm_cmpgt_pd(a.into(), b.into())).simd_into(self) } + unsafe { _mm_castpd_si128(_mm_cmp_pd::<30i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn zip_low_f64x2(self, a: f64x2, b: f64x2) -> f64x2 { @@ -1339,10 +1267,7 @@ impl Simd for Avx2 { } #[inline(always)] fn select_f64x2(self, a: mask64x2, b: f64x2, c: f64x2) -> f64x2 { - unsafe { - let mask = _mm_castsi128_pd(a.into()); - _mm_or_pd(_mm_and_pd(mask, b.into()), _mm_andnot_pd(mask, c.into())).simd_into(self) - } + unsafe { _mm_blendv_pd(c.into(), b.into(), _mm_castsi128_pd(a.into())).simd_into(self) } } #[inline(always)] fn combine_f64x2(self, a: f64x2, b: f64x2) -> f64x4 { @@ -1385,13 +1310,7 @@ impl Simd for Avx2 { b: mask64x2, c: mask64x2, ) -> mask64x2 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask64x2(self, a: mask64x2, b: mask64x2) -> mask64x2 { @@ -1405,174 +1324,141 @@ impl Simd for Avx2 { result.simd_into(self) } #[inline(always)] - fn splat_f32x8(self, a: f32) -> f32x8 { - let half = self.splat_f32x4(a); - self.combine_f32x4(half, half) + fn splat_f32x8(self, val: f32) -> f32x8 { + unsafe { _mm256_set1_ps(val).simd_into(self) } } #[inline(always)] fn abs_f32x8(self, a: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_f32x4(self.abs_f32x4(a0), self.abs_f32x4(a1)) + unsafe { _mm256_andnot_ps(_mm256_set1_ps(-0.0), a.into()).simd_into(self) } } #[inline(always)] fn neg_f32x8(self, a: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_f32x4(self.neg_f32x4(a0), self.neg_f32x4(a1)) + unsafe { _mm256_xor_ps(a.into(), _mm256_set1_ps(-0.0)).simd_into(self) } } #[inline(always)] fn sqrt_f32x8(self, a: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_f32x4(self.sqrt_f32x4(a0), self.sqrt_f32x4(a1)) + unsafe { _mm256_sqrt_ps(a.into()).simd_into(self) } } #[inline(always)] fn add_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.add_f32x4(a0, b0), self.add_f32x4(a1, b1)) + unsafe { _mm256_add_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.sub_f32x4(a0, b0), self.sub_f32x4(a1, b1)) + unsafe { _mm256_sub_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.mul_f32x4(a0, b0), self.mul_f32x4(a1, b1)) + unsafe { 
_mm256_mul_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn div_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.div_f32x4(a0, b0), self.div_f32x4(a1, b1)) + unsafe { _mm256_div_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn copysign_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.copysign_f32x4(a0, b0), self.copysign_f32x4(a1, b1)) + unsafe { + let mask = _mm256_set1_ps(-0.0); + _mm256_or_ps( + _mm256_and_ps(mask, b.into()), + _mm256_andnot_ps(mask, a.into()), + ) + .simd_into(self) + } } #[inline(always)] fn simd_eq_f32x8(self, a: f32x8, b: f32x8) -> mask32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_mask32x4(self.simd_eq_f32x4(a0, b0), self.simd_eq_f32x4(a1, b1)) + unsafe { _mm256_castps_si256(_mm256_cmp_ps::<0i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_lt_f32x8(self, a: f32x8, b: f32x8) -> mask32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_mask32x4(self.simd_lt_f32x4(a0, b0), self.simd_lt_f32x4(a1, b1)) + unsafe { _mm256_castps_si256(_mm256_cmp_ps::<17i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_le_f32x8(self, a: f32x8, b: f32x8) -> mask32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_mask32x4(self.simd_le_f32x4(a0, b0), self.simd_le_f32x4(a1, b1)) + unsafe { _mm256_castps_si256(_mm256_cmp_ps::<18i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_ge_f32x8(self, a: f32x8, b: f32x8) -> mask32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_mask32x4(self.simd_ge_f32x4(a0, b0), self.simd_ge_f32x4(a1, b1)) + unsafe { _mm256_castps_si256(_mm256_cmp_ps::<29i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_gt_f32x8(self, a: f32x8, b: f32x8) -> mask32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_mask32x4(self.simd_gt_f32x4(a0, b0), self.simd_gt_f32x4(a1, b1)) + unsafe { _mm256_castps_si256(_mm256_cmp_ps::<30i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn zip_low_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, _) = self.split_f32x8(a); - let (b0, _) = self.split_f32x8(b); - self.combine_f32x4(self.zip_low_f32x4(a0, b0), self.zip_high_f32x4(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_ps(a.into(), b.into()); + let hi = _mm256_unpackhi_ps(a.into(), b.into()); + _mm256_permute2f128_ps::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (_, a1) = self.split_f32x8(a); - let (_, b1) = self.split_f32x8(b); - self.combine_f32x4(self.zip_low_f32x4(a1, b1), self.zip_high_f32x4(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_ps(a.into(), b.into()); + let hi = _mm256_unpackhi_ps(a.into(), b.into()); + _mm256_permute2f128_ps::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.unzip_low_f32x4(a0, a1), self.unzip_low_f32x4(b0, b1)) + unsafe { + let t1 = _mm256_permutevar8x32_ps(a.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + let t2 = _mm256_permutevar8x32_ps(b.into(), 
_mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + _mm256_permute2f128_ps::<0b0010_0000>(t1, t2).simd_into(self) + } } #[inline(always)] fn unzip_high_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.unzip_high_f32x4(a0, a1), self.unzip_high_f32x4(b0, b1)) + unsafe { + let t1 = _mm256_permutevar8x32_ps(a.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + let t2 = _mm256_permutevar8x32_ps(b.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + _mm256_permute2f128_ps::<0b0011_0001>(t1, t2).simd_into(self) + } } #[inline(always)] fn max_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.max_f32x4(a0, b0), self.max_f32x4(a1, b1)) + unsafe { _mm256_max_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_precise_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4( - self.max_precise_f32x4(a0, b0), - self.max_precise_f32x4(a1, b1), - ) + unsafe { _mm256_max_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn min_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4(self.min_f32x4(a0, b0), self.min_f32x4(a1, b1)) + unsafe { _mm256_min_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn min_precise_f32x8(self, a: f32x8, b: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - self.combine_f32x4( - self.min_precise_f32x4(a0, b0), - self.min_precise_f32x4(a1, b1), - ) + unsafe { _mm256_min_ps(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn madd_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - let (c0, c1) = self.split_f32x8(c); - self.combine_f32x4(self.madd_f32x4(a0, b0, c0), self.madd_f32x4(a1, b1, c1)) + unsafe { _mm256_fmadd_ps(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] fn msub_f32x8(self, a: f32x8, b: f32x8, c: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - let (b0, b1) = self.split_f32x8(b); - let (c0, c1) = self.split_f32x8(c); - self.combine_f32x4(self.msub_f32x4(a0, b0, c0), self.msub_f32x4(a1, b1, c1)) + unsafe { _mm256_fmsub_ps(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] fn floor_f32x8(self, a: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_f32x4(self.floor_f32x4(a0), self.floor_f32x4(a1)) + unsafe { _mm256_floor_ps(a.into()).simd_into(self) } } #[inline(always)] fn fract_f32x8(self, a: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_f32x4(self.fract_f32x4(a0), self.fract_f32x4(a1)) + a - a.trunc() } #[inline(always)] fn trunc_f32x8(self, a: f32x8) -> f32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_f32x4(self.trunc_f32x4(a0), self.trunc_f32x4(a1)) + unsafe { _mm256_round_ps(a.into(), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC).simd_into(self) } } #[inline(always)] fn select_f32x8(self, a: mask32x8, b: f32x8, c: f32x8) -> f32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_f32x8(b); - let (c0, c1) = self.split_f32x8(c); - self.combine_f32x4(self.select_f32x4(a0, b0, c0), self.select_f32x4(a1, b1, c1)) + unsafe { + _mm256_blendv_ps(c.into(), b.into(), _mm256_castsi256_ps(a.into())).simd_into(self) + } } #[inline(always)] fn 
combine_f32x8(self, a: f32x8, b: f32x8) -> f32x16 { @@ -1591,177 +1477,185 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_f64_f32x8(self, a: f32x8) -> f64x4 { - let (a0, a1) = self.split_f32x8(a); - self.combine_f64x2( - self.reinterpret_f64_f32x4(a0), - self.reinterpret_f64_f32x4(a1), - ) + f64x4 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn reinterpret_i32_f32x8(self, a: f32x8) -> i32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_i32x4( - self.reinterpret_i32_f32x4(a0), - self.reinterpret_i32_f32x4(a1), - ) + i32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn reinterpret_u8_f32x8(self, a: f32x8) -> u8x32 { - let (a0, a1) = self.split_f32x8(a); - self.combine_u8x16(self.reinterpret_u8_f32x4(a0), self.reinterpret_u8_f32x4(a1)) + u8x32 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn reinterpret_u32_f32x8(self, a: f32x8) -> u32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_u32x4( - self.reinterpret_u32_f32x4(a0), - self.reinterpret_u32_f32x4(a1), - ) + u32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn cvt_u32_f32x8(self, a: f32x8) -> u32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_u32x4(self.cvt_u32_f32x4(a0), self.cvt_u32_f32x4(a1)) + unsafe { + _mm256_cvtps_epi32(_mm256_max_ps( + _mm256_floor_ps(a.into()), + _mm256_set1_ps(0.0), + )) + .simd_into(self) + } } #[inline(always)] fn cvt_i32_f32x8(self, a: f32x8) -> i32x8 { - let (a0, a1) = self.split_f32x8(a); - self.combine_i32x4(self.cvt_i32_f32x4(a0), self.cvt_i32_f32x4(a1)) + unsafe { _mm256_cvtps_epi32(a.trunc().into()).simd_into(self) } } #[inline(always)] - fn splat_i8x32(self, a: i8) -> i8x32 { - let half = self.splat_i8x16(a); - self.combine_i8x16(half, half) + fn splat_i8x32(self, val: i8) -> i8x32 { + unsafe { _mm256_set1_epi8(val).simd_into(self) } } #[inline(always)] fn not_i8x32(self, a: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - self.combine_i8x16(self.not_i8x16(a0), self.not_i8x16(a1)) + a ^ !0 } #[inline(always)] fn add_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.add_i8x16(a0, b0), self.add_i8x16(a1, b1)) + unsafe { _mm256_add_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.sub_i8x16(a0, b0), self.sub_i8x16(a1, b1)) + unsafe { _mm256_sub_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.mul_i8x16(a0, b0), self.mul_i8x16(a1, b1)) + todo!() } #[inline(always)] fn and_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.and_i8x16(a0, b0), self.and_i8x16(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.or_i8x16(a0, b0), self.or_i8x16(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = 
self.split_i8x32(b); - self.combine_i8x16(self.xor_i8x16(a0, b0), self.xor_i8x16(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shr_i8x32(self, a: i8x32, b: u32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - self.combine_i8x16(self.shr_i8x16(a0, b), self.shr_i8x16(a1, b)) + fn shr_i8x32(self, a: i8x32, shift: u32) -> i8x32 { + unsafe { + let val = a.into(); + let shift_count = _mm_cvtsi32_si128(shift as i32); + let lo_16 = _mm256_unpacklo_epi8(val, _mm256_cmpgt_epi8(_mm256_setzero_si256(), val)); + let hi_16 = _mm256_unpackhi_epi8(val, _mm256_cmpgt_epi8(_mm256_setzero_si256(), val)); + let lo_shifted = _mm256_sra_epi16(lo_16, shift_count); + let hi_shifted = _mm256_sra_epi16(hi_16, shift_count); + _mm256_packs_epi16(lo_shifted, hi_shifted).simd_into(self) + } } #[inline(always)] fn shrv_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.shrv_i8x16(a0, b0), self.shrv_i8x16(a1, b1)) + core::array::from_fn(|i| core::ops::Shr::shr(a.val[i], b.val[i])).simd_into(self) } #[inline(always)] - fn shl_i8x32(self, a: i8x32, b: u32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - self.combine_i8x16(self.shl_i8x16(a0, b), self.shl_i8x16(a1, b)) + fn shl_i8x32(self, a: i8x32, shift: u32) -> i8x32 { + unsafe { + let val = a.into(); + let shift_count = _mm_cvtsi32_si128(shift as i32); + let lo_16 = _mm256_unpacklo_epi8(val, _mm256_cmpgt_epi8(_mm256_setzero_si256(), val)); + let hi_16 = _mm256_unpackhi_epi8(val, _mm256_cmpgt_epi8(_mm256_setzero_si256(), val)); + let lo_shifted = _mm256_sll_epi16(lo_16, shift_count); + let hi_shifted = _mm256_sll_epi16(hi_16, shift_count); + _mm256_packs_epi16(lo_shifted, hi_shifted).simd_into(self) + } } #[inline(always)] fn simd_eq_i8x32(self, a: i8x32, b: i8x32) -> mask8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_mask8x16(self.simd_eq_i8x16(a0, b0), self.simd_eq_i8x16(a1, b1)) + unsafe { _mm256_cmpeq_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_i8x32(self, a: i8x32, b: i8x32) -> mask8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_mask8x16(self.simd_lt_i8x16(a0, b0), self.simd_lt_i8x16(a1, b1)) + unsafe { _mm256_cmpgt_epi8(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i8x32(self, a: i8x32, b: i8x32) -> mask8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_mask8x16(self.simd_le_i8x16(a0, b0), self.simd_le_i8x16(a1, b1)) + unsafe { _mm256_cmpeq_epi8(_mm256_min_epi8(a.into(), b.into()), a.into()).simd_into(self) } } #[inline(always)] fn simd_ge_i8x32(self, a: i8x32, b: i8x32) -> mask8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_mask8x16(self.simd_ge_i8x16(a0, b0), self.simd_ge_i8x16(a1, b1)) + unsafe { _mm256_cmpeq_epi8(_mm256_max_epi8(a.into(), b.into()), a.into()).simd_into(self) } } #[inline(always)] fn simd_gt_i8x32(self, a: i8x32, b: i8x32) -> mask8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_mask8x16(self.simd_gt_i8x16(a0, b0), self.simd_gt_i8x16(a1, b1)) + unsafe { _mm256_cmpgt_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn zip_low_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, _) = self.split_i8x32(a); - let (b0, _) = self.split_i8x32(b); - self.combine_i8x16(self.zip_low_i8x16(a0, b0), 
self.zip_high_i8x16(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_epi8(a.into(), b.into()); + let hi = _mm256_unpackhi_epi8(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (_, a1) = self.split_i8x32(a); - let (_, b1) = self.split_i8x32(b); - self.combine_i8x16(self.zip_low_i8x16(a1, b1), self.zip_high_i8x16(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_epi8(a.into(), b.into()); + let hi = _mm256_unpackhi_epi8(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.unzip_low_i8x16(a0, a1), self.unzip_low_i8x16(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, 0, 2, 4, 6, 8, 10, 12, 14, 1, + 3, 5, 7, 9, 11, 13, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + let packed = _mm256_permute2x128_si256::<0b0010_0000>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn unzip_high_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.unzip_high_i8x16(a0, a1), self.unzip_high_i8x16(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, 0, 2, 4, 6, 8, 10, 12, 14, 1, + 3, 5, 7, 9, 11, 13, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + let packed = _mm256_permute2x128_si256::<0b0011_0001>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn select_i8x32(self, a: mask8x32, b: i8x32, c: i8x32) -> i8x32 { - let (a0, a1) = self.split_mask8x32(a); - let (b0, b1) = self.split_i8x32(b); - let (c0, c1) = self.split_i8x32(c); - self.combine_i8x16(self.select_i8x16(a0, b0, c0), self.select_i8x16(a1, b1, c1)) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.min_i8x16(a0, b0), self.min_i8x16(a1, b1)) + unsafe { _mm256_min_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - let (b0, b1) = self.split_i8x32(b); - self.combine_i8x16(self.max_i8x16(a0, b0), self.max_i8x16(a1, b1)) + unsafe { _mm256_max_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_i8x32(self, a: i8x32, b: i8x32) -> i8x64 { @@ -1780,156 +1674,171 @@ impl Simd for Avx2 { } #[inline(always)] fn neg_i8x32(self, a: i8x32) -> i8x32 { - let (a0, a1) = self.split_i8x32(a); - self.combine_i8x16(self.neg_i8x16(a0), self.neg_i8x16(a1)) + unsafe { _mm256_sub_epi8(_mm256_setzero_si256(), a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_i8x32(self, a: i8x32) -> u8x32 { - let (a0, a1) = self.split_i8x32(a); - self.combine_u8x16(self.reinterpret_u8_i8x16(a0), self.reinterpret_u8_i8x16(a1)) + u8x32 { + val: 
bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn reinterpret_u32_i8x32(self, a: i8x32) -> u32x8 { - let (a0, a1) = self.split_i8x32(a); - self.combine_u32x4( - self.reinterpret_u32_i8x16(a0), - self.reinterpret_u32_i8x16(a1), - ) + u32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] - fn splat_u8x32(self, a: u8) -> u8x32 { - let half = self.splat_u8x16(a); - self.combine_u8x16(half, half) + fn splat_u8x32(self, val: u8) -> u8x32 { + unsafe { _mm256_set1_epi8(val as _).simd_into(self) } } #[inline(always)] fn not_u8x32(self, a: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - self.combine_u8x16(self.not_u8x16(a0), self.not_u8x16(a1)) + a ^ !0 } #[inline(always)] fn add_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.add_u8x16(a0, b0), self.add_u8x16(a1, b1)) + unsafe { _mm256_add_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.sub_u8x16(a0, b0), self.sub_u8x16(a1, b1)) + unsafe { _mm256_sub_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.mul_u8x16(a0, b0), self.mul_u8x16(a1, b1)) + todo!() } #[inline(always)] fn and_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.and_u8x16(a0, b0), self.and_u8x16(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.or_u8x16(a0, b0), self.or_u8x16(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.xor_u8x16(a0, b0), self.xor_u8x16(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shr_u8x32(self, a: u8x32, b: u32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - self.combine_u8x16(self.shr_u8x16(a0, b), self.shr_u8x16(a1, b)) + fn shr_u8x32(self, a: u8x32, shift: u32) -> u8x32 { + unsafe { + let val = a.into(); + let shift_count = _mm_cvtsi32_si128(shift as i32); + let lo_16 = _mm256_unpacklo_epi8(val, _mm256_setzero_si256()); + let hi_16 = _mm256_unpackhi_epi8(val, _mm256_setzero_si256()); + let lo_shifted = _mm256_srl_epi16(lo_16, shift_count); + let hi_shifted = _mm256_srl_epi16(hi_16, shift_count); + _mm256_packus_epi16(lo_shifted, hi_shifted).simd_into(self) + } } #[inline(always)] fn shrv_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.shrv_u8x16(a0, b0), self.shrv_u8x16(a1, b1)) + core::array::from_fn(|i| core::ops::Shr::shr(a.val[i], b.val[i])).simd_into(self) } #[inline(always)] - fn shl_u8x32(self, a: u8x32, b: u32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - self.combine_u8x16(self.shl_u8x16(a0, b), self.shl_u8x16(a1, b)) + fn shl_u8x32(self, a: u8x32, shift: u32) -> u8x32 { + unsafe { + let val = a.into(); + let shift_count = 
_mm_cvtsi32_si128(shift as i32); + let lo_16 = _mm256_unpacklo_epi8(val, _mm256_setzero_si256()); + let hi_16 = _mm256_unpackhi_epi8(val, _mm256_setzero_si256()); + let lo_shifted = _mm256_sll_epi16(lo_16, shift_count); + let hi_shifted = _mm256_sll_epi16(hi_16, shift_count); + _mm256_packus_epi16(lo_shifted, hi_shifted).simd_into(self) + } } #[inline(always)] fn simd_eq_u8x32(self, a: u8x32, b: u8x32) -> mask8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_mask8x16(self.simd_eq_u8x16(a0, b0), self.simd_eq_u8x16(a1, b1)) + unsafe { _mm256_cmpeq_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u8x32(self, a: u8x32, b: u8x32) -> mask8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_mask8x16(self.simd_lt_u8x16(a0, b0), self.simd_lt_u8x16(a1, b1)) + unsafe { + let sign_bit = _mm256_set1_epi8(0x80u8 as _); + let a_signed = _mm256_xor_si256(a.into(), sign_bit); + let b_signed = _mm256_xor_si256(b.into(), sign_bit); + _mm256_cmpgt_epi8(b_signed, a_signed).simd_into(self) + } } #[inline(always)] fn simd_le_u8x32(self, a: u8x32, b: u8x32) -> mask8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_mask8x16(self.simd_le_u8x16(a0, b0), self.simd_le_u8x16(a1, b1)) + unsafe { _mm256_cmpeq_epi8(_mm256_min_epu8(a.into(), b.into()), a.into()).simd_into(self) } } #[inline(always)] fn simd_ge_u8x32(self, a: u8x32, b: u8x32) -> mask8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_mask8x16(self.simd_ge_u8x16(a0, b0), self.simd_ge_u8x16(a1, b1)) + unsafe { _mm256_cmpeq_epi8(_mm256_max_epu8(a.into(), b.into()), a.into()).simd_into(self) } } #[inline(always)] fn simd_gt_u8x32(self, a: u8x32, b: u8x32) -> mask8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_mask8x16(self.simd_gt_u8x16(a0, b0), self.simd_gt_u8x16(a1, b1)) + unsafe { + let sign_bit = _mm256_set1_epi8(0x80u8 as _); + let a_signed = _mm256_xor_si256(a.into(), sign_bit); + let b_signed = _mm256_xor_si256(b.into(), sign_bit); + _mm256_cmpgt_epi8(a_signed, b_signed).simd_into(self) + } } #[inline(always)] fn zip_low_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, _) = self.split_u8x32(a); - let (b0, _) = self.split_u8x32(b); - self.combine_u8x16(self.zip_low_u8x16(a0, b0), self.zip_high_u8x16(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_epi8(a.into(), b.into()); + let hi = _mm256_unpackhi_epi8(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (_, a1) = self.split_u8x32(a); - let (_, b1) = self.split_u8x32(b); - self.combine_u8x16(self.zip_low_u8x16(a1, b1), self.zip_high_u8x16(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_epi8(a.into(), b.into()); + let hi = _mm256_unpackhi_epi8(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.unzip_low_u8x16(a0, a1), self.unzip_low_u8x16(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, 0, 2, 4, 6, 8, 10, 12, 14, 1, + 3, 5, 7, 9, 11, 13, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); 
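+            // The in-lane byte shuffle above groups the even-indexed bytes of each
+            // 128-bit lane into its low 8 bytes and the odd-indexed bytes into its
+            // high 8 bytes; the permute4x64 below then gathers both even halves into
+            // the low 128 bits of each operand, so the final permute2x128 (selector
+            // 0b0010_0000) concatenates the even-indexed bytes of `a` with those of `b`.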
+ let packed = _mm256_permute2x128_si256::<0b0010_0000>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn unzip_high_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.unzip_high_u8x16(a0, a1), self.unzip_high_u8x16(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, 0, 2, 4, 6, 8, 10, 12, 14, 1, + 3, 5, 7, 9, 11, 13, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + let packed = _mm256_permute2x128_si256::<0b0011_0001>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn select_u8x32(self, a: mask8x32, b: u8x32, c: u8x32) -> u8x32 { - let (a0, a1) = self.split_mask8x32(a); - let (b0, b1) = self.split_u8x32(b); - let (c0, c1) = self.split_u8x32(c); - self.combine_u8x16(self.select_u8x16(a0, b0, c0), self.select_u8x16(a1, b1, c1)) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.min_u8x16(a0, b0), self.min_u8x16(a1, b1)) + unsafe { _mm256_min_epu8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - let (a0, a1) = self.split_u8x32(a); - let (b0, b1) = self.split_u8x32(b); - self.combine_u8x16(self.max_u8x16(a0, b0), self.max_u8x16(a1, b1)) + unsafe { _mm256_max_epu8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_u8x32(self, a: u8x32, b: u8x32) -> u8x64 { @@ -1953,39 +1862,30 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u32_u8x32(self, a: u8x32) -> u32x8 { - let (a0, a1) = self.split_u8x32(a); - self.combine_u32x4( - self.reinterpret_u32_u8x16(a0), - self.reinterpret_u32_u8x16(a1), - ) + u32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] - fn splat_mask8x32(self, a: i8) -> mask8x32 { - let half = self.splat_mask8x16(a); - self.combine_mask8x16(half, half) + fn splat_mask8x32(self, val: i8) -> mask8x32 { + unsafe { _mm256_set1_epi8(val).simd_into(self) } } #[inline(always)] fn not_mask8x32(self, a: mask8x32) -> mask8x32 { - let (a0, a1) = self.split_mask8x32(a); - self.combine_mask8x16(self.not_mask8x16(a0), self.not_mask8x16(a1)) + a ^ !0 } #[inline(always)] fn and_mask8x32(self, a: mask8x32, b: mask8x32) -> mask8x32 { - let (a0, a1) = self.split_mask8x32(a); - let (b0, b1) = self.split_mask8x32(b); - self.combine_mask8x16(self.and_mask8x16(a0, b0), self.and_mask8x16(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_mask8x32(self, a: mask8x32, b: mask8x32) -> mask8x32 { - let (a0, a1) = self.split_mask8x32(a); - let (b0, b1) = self.split_mask8x32(b); - self.combine_mask8x16(self.or_mask8x16(a0, b0), self.or_mask8x16(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_mask8x32(self, a: mask8x32, b: mask8x32) -> mask8x32 { - let (a0, a1) = self.split_mask8x32(a); - let (b0, b1) = self.split_mask8x32(b); - self.combine_mask8x16(self.xor_mask8x16(a0, b0), self.xor_mask8x16(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), 
b.into()).simd_into(self) } } #[inline(always)] fn select_mask8x32( @@ -1994,19 +1894,11 @@ impl Simd for Avx2 { b: mask8x32, c: mask8x32, ) -> mask8x32 { - let (a0, a1) = self.split_mask8x32(a); - let (b0, b1) = self.split_mask8x32(b); - let (c0, c1) = self.split_mask8x32(c); - self.combine_mask8x16( - self.select_mask8x16(a0, b0, c0), - self.select_mask8x16(a1, b1, c1), - ) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask8x32(self, a: mask8x32, b: mask8x32) -> mask8x32 { - let (a0, a1) = self.split_mask8x32(a); - let (b0, b1) = self.split_mask8x32(b); - self.combine_mask8x16(self.simd_eq_mask8x16(a0, b0), self.simd_eq_mask8x16(a1, b1)) + unsafe { _mm256_cmpeq_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_mask8x32(self, a: mask8x32, b: mask8x32) -> mask8x64 { @@ -2024,139 +1916,132 @@ impl Simd for Avx2 { (b0.simd_into(self), b1.simd_into(self)) } #[inline(always)] - fn splat_i16x16(self, a: i16) -> i16x16 { - let half = self.splat_i16x8(a); - self.combine_i16x8(half, half) + fn splat_i16x16(self, val: i16) -> i16x16 { + unsafe { _mm256_set1_epi16(val).simd_into(self) } } #[inline(always)] fn not_i16x16(self, a: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - self.combine_i16x8(self.not_i16x8(a0), self.not_i16x8(a1)) + a ^ !0 } #[inline(always)] fn add_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.add_i16x8(a0, b0), self.add_i16x8(a1, b1)) + unsafe { _mm256_add_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.sub_i16x8(a0, b0), self.sub_i16x8(a1, b1)) + unsafe { _mm256_sub_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.mul_i16x8(a0, b0), self.mul_i16x8(a1, b1)) + unsafe { _mm256_mullo_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn and_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.and_i16x8(a0, b0), self.and_i16x8(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.or_i16x8(a0, b0), self.or_i16x8(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.xor_i16x8(a0, b0), self.xor_i16x8(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shr_i16x16(self, a: i16x16, b: u32) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - self.combine_i16x8(self.shr_i16x8(a0, b), self.shr_i16x8(a1, b)) + fn shr_i16x16(self, a: i16x16, shift: u32) -> i16x16 { + unsafe { _mm256_sra_epi16(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn shrv_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - 
self.combine_i16x8(self.shrv_i16x8(a0, b0), self.shrv_i16x8(a1, b1)) + core::array::from_fn(|i| core::ops::Shr::shr(a.val[i], b.val[i])).simd_into(self) } #[inline(always)] - fn shl_i16x16(self, a: i16x16, b: u32) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - self.combine_i16x8(self.shl_i16x8(a0, b), self.shl_i16x8(a1, b)) + fn shl_i16x16(self, a: i16x16, shift: u32) -> i16x16 { + unsafe { _mm256_sll_epi16(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn simd_eq_i16x16(self, a: i16x16, b: i16x16) -> mask16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_mask16x8(self.simd_eq_i16x8(a0, b0), self.simd_eq_i16x8(a1, b1)) + unsafe { _mm256_cmpeq_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_i16x16(self, a: i16x16, b: i16x16) -> mask16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_mask16x8(self.simd_lt_i16x8(a0, b0), self.simd_lt_i16x8(a1, b1)) + unsafe { _mm256_cmpgt_epi16(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i16x16(self, a: i16x16, b: i16x16) -> mask16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_mask16x8(self.simd_le_i16x8(a0, b0), self.simd_le_i16x8(a1, b1)) + unsafe { + _mm256_cmpeq_epi16(_mm256_min_epi16(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_ge_i16x16(self, a: i16x16, b: i16x16) -> mask16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_mask16x8(self.simd_ge_i16x8(a0, b0), self.simd_ge_i16x8(a1, b1)) + unsafe { + _mm256_cmpeq_epi16(_mm256_max_epi16(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_gt_i16x16(self, a: i16x16, b: i16x16) -> mask16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_mask16x8(self.simd_gt_i16x8(a0, b0), self.simd_gt_i16x8(a1, b1)) + unsafe { _mm256_cmpgt_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn zip_low_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, _) = self.split_i16x16(a); - let (b0, _) = self.split_i16x16(b); - self.combine_i16x8(self.zip_low_i16x8(a0, b0), self.zip_high_i16x8(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_epi16(a.into(), b.into()); + let hi = _mm256_unpackhi_epi16(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (_, a1) = self.split_i16x16(a); - let (_, b1) = self.split_i16x16(b); - self.combine_i16x8(self.zip_low_i16x8(a1, b1), self.zip_high_i16x8(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_epi16(a.into(), b.into()); + let hi = _mm256_unpackhi_epi16(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.unzip_low_i16x8(a0, a1), self.unzip_low_i16x8(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, + 3, 6, 7, 10, 11, 14, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + let packed = _mm256_permute2x128_si256::<0b0010_0000>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + 
_mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn unzip_high_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.unzip_high_i16x8(a0, a1), self.unzip_high_i16x8(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, + 3, 6, 7, 10, 11, 14, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + let packed = _mm256_permute2x128_si256::<0b0011_0001>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn select_i16x16(self, a: mask16x16, b: i16x16, c: i16x16) -> i16x16 { - let (a0, a1) = self.split_mask16x16(a); - let (b0, b1) = self.split_i16x16(b); - let (c0, c1) = self.split_i16x16(c); - self.combine_i16x8(self.select_i16x8(a0, b0, c0), self.select_i16x8(a1, b1, c1)) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.min_i16x8(a0, b0), self.min_i16x8(a1, b1)) + unsafe { _mm256_min_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_i16x16(self, a: i16x16, b: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - let (b0, b1) = self.split_i16x16(b); - self.combine_i16x8(self.max_i16x8(a0, b0), self.max_i16x8(a1, b1)) + unsafe { _mm256_max_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_i16x16(self, a: i16x16, b: i16x16) -> i16x32 { @@ -2175,156 +2060,159 @@ impl Simd for Avx2 { } #[inline(always)] fn neg_i16x16(self, a: i16x16) -> i16x16 { - let (a0, a1) = self.split_i16x16(a); - self.combine_i16x8(self.neg_i16x8(a0), self.neg_i16x8(a1)) + unsafe { _mm256_sub_epi16(_mm256_setzero_si256(), a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_i16x16(self, a: i16x16) -> u8x32 { - let (a0, a1) = self.split_i16x16(a); - self.combine_u8x16(self.reinterpret_u8_i16x8(a0), self.reinterpret_u8_i16x8(a1)) + u8x32 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn reinterpret_u32_i16x16(self, a: i16x16) -> u32x8 { - let (a0, a1) = self.split_i16x16(a); - self.combine_u32x4( - self.reinterpret_u32_i16x8(a0), - self.reinterpret_u32_i16x8(a1), - ) + u32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] - fn splat_u16x16(self, a: u16) -> u16x16 { - let half = self.splat_u16x8(a); - self.combine_u16x8(half, half) + fn splat_u16x16(self, val: u16) -> u16x16 { + unsafe { _mm256_set1_epi16(val as _).simd_into(self) } } #[inline(always)] fn not_u16x16(self, a: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - self.combine_u16x8(self.not_u16x8(a0), self.not_u16x8(a1)) + a ^ !0 } #[inline(always)] fn add_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.add_u16x8(a0, b0), self.add_u16x8(a1, b1)) + unsafe { _mm256_add_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.sub_u16x8(a0, b0), self.sub_u16x8(a1, 
b1)) + unsafe { _mm256_sub_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.mul_u16x8(a0, b0), self.mul_u16x8(a1, b1)) + unsafe { _mm256_mullo_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn and_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.and_u16x8(a0, b0), self.and_u16x8(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.or_u16x8(a0, b0), self.or_u16x8(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.xor_u16x8(a0, b0), self.xor_u16x8(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shr_u16x16(self, a: u16x16, b: u32) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - self.combine_u16x8(self.shr_u16x8(a0, b), self.shr_u16x8(a1, b)) + fn shr_u16x16(self, a: u16x16, shift: u32) -> u16x16 { + unsafe { _mm256_srl_epi16(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn shrv_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.shrv_u16x8(a0, b0), self.shrv_u16x8(a1, b1)) + core::array::from_fn(|i| core::ops::Shr::shr(a.val[i], b.val[i])).simd_into(self) } #[inline(always)] - fn shl_u16x16(self, a: u16x16, b: u32) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - self.combine_u16x8(self.shl_u16x8(a0, b), self.shl_u16x8(a1, b)) + fn shl_u16x16(self, a: u16x16, shift: u32) -> u16x16 { + unsafe { _mm256_sll_epi16(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn simd_eq_u16x16(self, a: u16x16, b: u16x16) -> mask16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_mask16x8(self.simd_eq_u16x8(a0, b0), self.simd_eq_u16x8(a1, b1)) + unsafe { _mm256_cmpeq_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u16x16(self, a: u16x16, b: u16x16) -> mask16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_mask16x8(self.simd_lt_u16x8(a0, b0), self.simd_lt_u16x8(a1, b1)) + unsafe { + let sign_bit = _mm256_set1_epi16(0x8000u16 as _); + let a_signed = _mm256_xor_si256(a.into(), sign_bit); + let b_signed = _mm256_xor_si256(b.into(), sign_bit); + _mm256_cmpgt_epi16(b_signed, a_signed).simd_into(self) + } } #[inline(always)] fn simd_le_u16x16(self, a: u16x16, b: u16x16) -> mask16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_mask16x8(self.simd_le_u16x8(a0, b0), self.simd_le_u16x8(a1, b1)) + unsafe { + _mm256_cmpeq_epi16(_mm256_min_epu16(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_ge_u16x16(self, a: u16x16, b: u16x16) -> mask16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_mask16x8(self.simd_ge_u16x8(a0, b0), self.simd_ge_u16x8(a1, b1)) + unsafe { + 
_mm256_cmpeq_epi16(_mm256_max_epu16(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_gt_u16x16(self, a: u16x16, b: u16x16) -> mask16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_mask16x8(self.simd_gt_u16x8(a0, b0), self.simd_gt_u16x8(a1, b1)) + unsafe { + let sign_bit = _mm256_set1_epi16(0x8000u16 as _); + let a_signed = _mm256_xor_si256(a.into(), sign_bit); + let b_signed = _mm256_xor_si256(b.into(), sign_bit); + _mm256_cmpgt_epi16(a_signed, b_signed).simd_into(self) + } } #[inline(always)] fn zip_low_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, _) = self.split_u16x16(a); - let (b0, _) = self.split_u16x16(b); - self.combine_u16x8(self.zip_low_u16x8(a0, b0), self.zip_high_u16x8(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_epi16(a.into(), b.into()); + let hi = _mm256_unpackhi_epi16(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (_, a1) = self.split_u16x16(a); - let (_, b1) = self.split_u16x16(b); - self.combine_u16x8(self.zip_low_u16x8(a1, b1), self.zip_high_u16x8(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_epi16(a.into(), b.into()); + let hi = _mm256_unpackhi_epi16(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.unzip_low_u16x8(a0, a1), self.unzip_low_u16x8(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, + 3, 6, 7, 10, 11, 14, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + let packed = _mm256_permute2x128_si256::<0b0010_0000>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn unzip_high_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.unzip_high_u16x8(a0, a1), self.unzip_high_u16x8(b0, b1)) + unsafe { + let mask = _mm256_setr_epi8( + 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, + 3, 6, 7, 10, 11, 14, 15, + ); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + let packed = _mm256_permute2x128_si256::<0b0011_0001>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled), + ); + packed.simd_into(self) + } } #[inline(always)] fn select_u16x16(self, a: mask16x16, b: u16x16, c: u16x16) -> u16x16 { - let (a0, a1) = self.split_mask16x16(a); - let (b0, b1) = self.split_u16x16(b); - let (c0, c1) = self.split_u16x16(c); - self.combine_u16x8(self.select_u16x8(a0, b0, c0), self.select_u16x8(a1, b1, c1)) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u16x16(self, a: u16x16, b: u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.min_u16x8(a0, b0), self.min_u16x8(a1, b1)) + unsafe { _mm256_min_epu16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_u16x16(self, a: u16x16, b: 
u16x16) -> u16x16 { - let (a0, a1) = self.split_u16x16(a); - let (b0, b1) = self.split_u16x16(b); - self.combine_u16x8(self.max_u16x8(a0, b0), self.max_u16x8(a1, b1)) + unsafe { _mm256_max_epu16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_u16x16(self, a: u16x16, b: u16x16) -> u16x32 { @@ -2354,44 +2242,37 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_u16x16(self, a: u16x16) -> u8x32 { - let (a0, a1) = self.split_u16x16(a); - self.combine_u8x16(self.reinterpret_u8_u16x8(a0), self.reinterpret_u8_u16x8(a1)) + u8x32 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn reinterpret_u32_u16x16(self, a: u16x16) -> u32x8 { - let (a0, a1) = self.split_u16x16(a); - self.combine_u32x4( - self.reinterpret_u32_u16x8(a0), - self.reinterpret_u32_u16x8(a1), - ) + u32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] - fn splat_mask16x16(self, a: i16) -> mask16x16 { - let half = self.splat_mask16x8(a); - self.combine_mask16x8(half, half) + fn splat_mask16x16(self, val: i16) -> mask16x16 { + unsafe { _mm256_set1_epi16(val).simd_into(self) } } #[inline(always)] fn not_mask16x16(self, a: mask16x16) -> mask16x16 { - let (a0, a1) = self.split_mask16x16(a); - self.combine_mask16x8(self.not_mask16x8(a0), self.not_mask16x8(a1)) + a ^ !0 } #[inline(always)] fn and_mask16x16(self, a: mask16x16, b: mask16x16) -> mask16x16 { - let (a0, a1) = self.split_mask16x16(a); - let (b0, b1) = self.split_mask16x16(b); - self.combine_mask16x8(self.and_mask16x8(a0, b0), self.and_mask16x8(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_mask16x16(self, a: mask16x16, b: mask16x16) -> mask16x16 { - let (a0, a1) = self.split_mask16x16(a); - let (b0, b1) = self.split_mask16x16(b); - self.combine_mask16x8(self.or_mask16x8(a0, b0), self.or_mask16x8(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_mask16x16(self, a: mask16x16, b: mask16x16) -> mask16x16 { - let (a0, a1) = self.split_mask16x16(a); - let (b0, b1) = self.split_mask16x16(b); - self.combine_mask16x8(self.xor_mask16x8(a0, b0), self.xor_mask16x8(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn select_mask16x16( @@ -2400,19 +2281,11 @@ impl Simd for Avx2 { b: mask16x16, c: mask16x16, ) -> mask16x16 { - let (a0, a1) = self.split_mask16x16(a); - let (b0, b1) = self.split_mask16x16(b); - let (c0, c1) = self.split_mask16x16(c); - self.combine_mask16x8( - self.select_mask16x8(a0, b0, c0), - self.select_mask16x8(a1, b1, c1), - ) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask16x16(self, a: mask16x16, b: mask16x16) -> mask16x16 { - let (a0, a1) = self.split_mask16x16(a); - let (b0, b1) = self.split_mask16x16(b); - self.combine_mask16x8(self.simd_eq_mask16x8(a0, b0), self.simd_eq_mask16x8(a1, b1)) + unsafe { _mm256_cmpeq_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_mask16x16(self, a: mask16x16, b: mask16x16) -> mask16x32 { @@ -2430,139 +2303,120 @@ impl Simd for Avx2 { (b0.simd_into(self), b1.simd_into(self)) } #[inline(always)] - fn splat_i32x8(self, a: i32) -> i32x8 { - let half = self.splat_i32x4(a); - self.combine_i32x4(half, half) + fn splat_i32x8(self, val: i32) -> i32x8 { + unsafe { _mm256_set1_epi32(val).simd_into(self) } } #[inline(always)] fn not_i32x8(self, a: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - 
self.combine_i32x4(self.not_i32x4(a0), self.not_i32x4(a1)) + a ^ !0 } #[inline(always)] fn add_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.add_i32x4(a0, b0), self.add_i32x4(a1, b1)) + unsafe { _mm256_add_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.sub_i32x4(a0, b0), self.sub_i32x4(a1, b1)) + unsafe { _mm256_sub_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.mul_i32x4(a0, b0), self.mul_i32x4(a1, b1)) + unsafe { _mm256_mullo_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn and_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.and_i32x4(a0, b0), self.and_i32x4(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.or_i32x4(a0, b0), self.or_i32x4(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.xor_i32x4(a0, b0), self.xor_i32x4(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shr_i32x8(self, a: i32x8, b: u32) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - self.combine_i32x4(self.shr_i32x4(a0, b), self.shr_i32x4(a1, b)) + fn shr_i32x8(self, a: i32x8, shift: u32) -> i32x8 { + unsafe { _mm256_sra_epi32(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn shrv_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.shrv_i32x4(a0, b0), self.shrv_i32x4(a1, b1)) + unsafe { _mm256_srav_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shl_i32x8(self, a: i32x8, b: u32) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - self.combine_i32x4(self.shl_i32x4(a0, b), self.shl_i32x4(a1, b)) + fn shl_i32x8(self, a: i32x8, shift: u32) -> i32x8 { + unsafe { _mm256_sll_epi32(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn simd_eq_i32x8(self, a: i32x8, b: i32x8) -> mask32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_mask32x4(self.simd_eq_i32x4(a0, b0), self.simd_eq_i32x4(a1, b1)) + unsafe { _mm256_cmpeq_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_i32x8(self, a: i32x8, b: i32x8) -> mask32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_mask32x4(self.simd_lt_i32x4(a0, b0), self.simd_lt_i32x4(a1, b1)) + unsafe { _mm256_cmpgt_epi32(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i32x8(self, a: i32x8, b: i32x8) -> mask32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_mask32x4(self.simd_le_i32x4(a0, b0), self.simd_le_i32x4(a1, b1)) + unsafe { + 
_mm256_cmpeq_epi32(_mm256_min_epi32(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_ge_i32x8(self, a: i32x8, b: i32x8) -> mask32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_mask32x4(self.simd_ge_i32x4(a0, b0), self.simd_ge_i32x4(a1, b1)) + unsafe { + _mm256_cmpeq_epi32(_mm256_max_epi32(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_gt_i32x8(self, a: i32x8, b: i32x8) -> mask32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_mask32x4(self.simd_gt_i32x4(a0, b0), self.simd_gt_i32x4(a1, b1)) + unsafe { _mm256_cmpgt_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn zip_low_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, _) = self.split_i32x8(a); - let (b0, _) = self.split_i32x8(b); - self.combine_i32x4(self.zip_low_i32x4(a0, b0), self.zip_high_i32x4(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_epi32(a.into(), b.into()); + let hi = _mm256_unpackhi_epi32(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (_, a1) = self.split_i32x8(a); - let (_, b1) = self.split_i32x8(b); - self.combine_i32x4(self.zip_low_i32x4(a1, b1), self.zip_high_i32x4(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_epi32(a.into(), b.into()); + let hi = _mm256_unpackhi_epi32(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.unzip_low_i32x4(a0, a1), self.unzip_low_i32x4(b0, b1)) + unsafe { + let t1 = + _mm256_permutevar8x32_epi32(a.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + let t2 = + _mm256_permutevar8x32_epi32(b.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + _mm256_permute2x128_si256::<0b0010_0000>(t1, t2).simd_into(self) + } } #[inline(always)] fn unzip_high_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.unzip_high_i32x4(a0, a1), self.unzip_high_i32x4(b0, b1)) + unsafe { + let t1 = + _mm256_permutevar8x32_epi32(a.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + let t2 = + _mm256_permutevar8x32_epi32(b.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + _mm256_permute2x128_si256::<0b0011_0001>(t1, t2).simd_into(self) + } } #[inline(always)] fn select_i32x8(self, a: mask32x8, b: i32x8, c: i32x8) -> i32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_i32x8(b); - let (c0, c1) = self.split_i32x8(c); - self.combine_i32x4(self.select_i32x4(a0, b0, c0), self.select_i32x4(a1, b1, c1)) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.min_i32x4(a0, b0), self.min_i32x4(a1, b1)) + unsafe { _mm256_min_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_i32x8(self, a: i32x8, b: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - let (b0, b1) = self.split_i32x8(b); - self.combine_i32x4(self.max_i32x4(a0, b0), self.max_i32x4(a1, b1)) + unsafe { _mm256_max_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_i32x8(self, a: i32x8, b: 
i32x8) -> i32x16 { @@ -2581,161 +2435,151 @@ impl Simd for Avx2 { } #[inline(always)] fn neg_i32x8(self, a: i32x8) -> i32x8 { - let (a0, a1) = self.split_i32x8(a); - self.combine_i32x4(self.neg_i32x4(a0), self.neg_i32x4(a1)) + unsafe { _mm256_sub_epi32(_mm256_setzero_si256(), a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_i32x8(self, a: i32x8) -> u8x32 { - let (a0, a1) = self.split_i32x8(a); - self.combine_u8x16(self.reinterpret_u8_i32x4(a0), self.reinterpret_u8_i32x4(a1)) + u8x32 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn reinterpret_u32_i32x8(self, a: i32x8) -> u32x8 { - let (a0, a1) = self.split_i32x8(a); - self.combine_u32x4( - self.reinterpret_u32_i32x4(a0), - self.reinterpret_u32_i32x4(a1), - ) + u32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn cvt_f32_i32x8(self, a: i32x8) -> f32x8 { - let (a0, a1) = self.split_i32x8(a); - self.combine_f32x4(self.cvt_f32_i32x4(a0), self.cvt_f32_i32x4(a1)) + unsafe { _mm256_cvtepi32_ps(a.into()).simd_into(self) } } #[inline(always)] - fn splat_u32x8(self, a: u32) -> u32x8 { - let half = self.splat_u32x4(a); - self.combine_u32x4(half, half) + fn splat_u32x8(self, val: u32) -> u32x8 { + unsafe { _mm256_set1_epi32(val as _).simd_into(self) } } #[inline(always)] fn not_u32x8(self, a: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - self.combine_u32x4(self.not_u32x4(a0), self.not_u32x4(a1)) + a ^ !0 } #[inline(always)] fn add_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.add_u32x4(a0, b0), self.add_u32x4(a1, b1)) + unsafe { _mm256_add_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.sub_u32x4(a0, b0), self.sub_u32x4(a1, b1)) + unsafe { _mm256_sub_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.mul_u32x4(a0, b0), self.mul_u32x4(a1, b1)) + unsafe { _mm256_mullo_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn and_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.and_u32x4(a0, b0), self.and_u32x4(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.or_u32x4(a0, b0), self.or_u32x4(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.xor_u32x4(a0, b0), self.xor_u32x4(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shr_u32x8(self, a: u32x8, b: u32) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - self.combine_u32x4(self.shr_u32x4(a0, b), self.shr_u32x4(a1, b)) + fn shr_u32x8(self, a: u32x8, shift: u32) -> u32x8 { + unsafe { _mm256_srl_epi32(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn shrv_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = 
self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.shrv_u32x4(a0, b0), self.shrv_u32x4(a1, b1)) + unsafe { _mm256_srlv_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] - fn shl_u32x8(self, a: u32x8, b: u32) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - self.combine_u32x4(self.shl_u32x4(a0, b), self.shl_u32x4(a1, b)) + fn shl_u32x8(self, a: u32x8, shift: u32) -> u32x8 { + unsafe { _mm256_sll_epi32(a.into(), _mm_cvtsi32_si128(shift as _)).simd_into(self) } } #[inline(always)] fn simd_eq_u32x8(self, a: u32x8, b: u32x8) -> mask32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_mask32x4(self.simd_eq_u32x4(a0, b0), self.simd_eq_u32x4(a1, b1)) + unsafe { _mm256_cmpeq_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u32x8(self, a: u32x8, b: u32x8) -> mask32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_mask32x4(self.simd_lt_u32x4(a0, b0), self.simd_lt_u32x4(a1, b1)) + unsafe { + let sign_bit = _mm256_set1_epi32(0x80000000u32 as _); + let a_signed = _mm256_xor_si256(a.into(), sign_bit); + let b_signed = _mm256_xor_si256(b.into(), sign_bit); + _mm256_cmpgt_epi32(b_signed, a_signed).simd_into(self) + } } #[inline(always)] fn simd_le_u32x8(self, a: u32x8, b: u32x8) -> mask32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_mask32x4(self.simd_le_u32x4(a0, b0), self.simd_le_u32x4(a1, b1)) + unsafe { + _mm256_cmpeq_epi32(_mm256_min_epu32(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_ge_u32x8(self, a: u32x8, b: u32x8) -> mask32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_mask32x4(self.simd_ge_u32x4(a0, b0), self.simd_ge_u32x4(a1, b1)) + unsafe { + _mm256_cmpeq_epi32(_mm256_max_epu32(a.into(), b.into()), a.into()).simd_into(self) + } } #[inline(always)] fn simd_gt_u32x8(self, a: u32x8, b: u32x8) -> mask32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_mask32x4(self.simd_gt_u32x4(a0, b0), self.simd_gt_u32x4(a1, b1)) + unsafe { + let sign_bit = _mm256_set1_epi32(0x80000000u32 as _); + let a_signed = _mm256_xor_si256(a.into(), sign_bit); + let b_signed = _mm256_xor_si256(b.into(), sign_bit); + _mm256_cmpgt_epi32(a_signed, b_signed).simd_into(self) + } } #[inline(always)] fn zip_low_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, _) = self.split_u32x8(a); - let (b0, _) = self.split_u32x8(b); - self.combine_u32x4(self.zip_low_u32x4(a0, b0), self.zip_high_u32x4(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_epi32(a.into(), b.into()); + let hi = _mm256_unpackhi_epi32(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (_, a1) = self.split_u32x8(a); - let (_, b1) = self.split_u32x8(b); - self.combine_u32x4(self.zip_low_u32x4(a1, b1), self.zip_high_u32x4(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_epi32(a.into(), b.into()); + let hi = _mm256_unpackhi_epi32(a.into(), b.into()); + _mm256_permute2x128_si256::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.unzip_low_u32x4(a0, a1), self.unzip_low_u32x4(b0, b1)) + unsafe { + let t1 = + _mm256_permutevar8x32_epi32(a.into(), 
_mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + let t2 = + _mm256_permutevar8x32_epi32(b.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + _mm256_permute2x128_si256::<0b0010_0000>(t1, t2).simd_into(self) + } } #[inline(always)] fn unzip_high_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.unzip_high_u32x4(a0, a1), self.unzip_high_u32x4(b0, b1)) + unsafe { + let t1 = + _mm256_permutevar8x32_epi32(a.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + let t2 = + _mm256_permutevar8x32_epi32(b.into(), _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); + _mm256_permute2x128_si256::<0b0011_0001>(t1, t2).simd_into(self) + } } #[inline(always)] fn select_u32x8(self, a: mask32x8, b: u32x8, c: u32x8) -> u32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_u32x8(b); - let (c0, c1) = self.split_u32x8(c); - self.combine_u32x4(self.select_u32x4(a0, b0, c0), self.select_u32x4(a1, b1, c1)) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.min_u32x4(a0, b0), self.min_u32x4(a1, b1)) + unsafe { _mm256_min_epu32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_u32x8(self, a: u32x8, b: u32x8) -> u32x8 { - let (a0, a1) = self.split_u32x8(a); - let (b0, b1) = self.split_u32x8(b); - self.combine_u32x4(self.max_u32x4(a0, b0), self.max_u32x4(a1, b1)) + unsafe { _mm256_max_epu32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_u32x8(self, a: u32x8, b: u32x8) -> u32x16 { @@ -2754,41 +2598,34 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_u32x8(self, a: u32x8) -> u8x32 { - let (a0, a1) = self.split_u32x8(a); - self.combine_u8x16(self.reinterpret_u8_u32x4(a0), self.reinterpret_u8_u32x4(a1)) + u8x32 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] fn cvt_f32_u32x8(self, a: u32x8) -> f32x8 { - let (a0, a1) = self.split_u32x8(a); - self.combine_f32x4(self.cvt_f32_u32x4(a0), self.cvt_f32_u32x4(a1)) + unsafe { _mm256_cvtepi32_ps(a.into()).simd_into(self) } } #[inline(always)] - fn splat_mask32x8(self, a: i32) -> mask32x8 { - let half = self.splat_mask32x4(a); - self.combine_mask32x4(half, half) + fn splat_mask32x8(self, val: i32) -> mask32x8 { + unsafe { _mm256_set1_epi32(val).simd_into(self) } } #[inline(always)] fn not_mask32x8(self, a: mask32x8) -> mask32x8 { - let (a0, a1) = self.split_mask32x8(a); - self.combine_mask32x4(self.not_mask32x4(a0), self.not_mask32x4(a1)) + a ^ !0 } #[inline(always)] fn and_mask32x8(self, a: mask32x8, b: mask32x8) -> mask32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_mask32x8(b); - self.combine_mask32x4(self.and_mask32x4(a0, b0), self.and_mask32x4(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_mask32x8(self, a: mask32x8, b: mask32x8) -> mask32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_mask32x8(b); - self.combine_mask32x4(self.or_mask32x4(a0, b0), self.or_mask32x4(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_mask32x8(self, a: mask32x8, b: mask32x8) -> mask32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_mask32x8(b); - self.combine_mask32x4(self.xor_mask32x4(a0, b0), self.xor_mask32x4(a1, b1)) + unsafe { 
_mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn select_mask32x8( @@ -2797,19 +2634,11 @@ impl Simd for Avx2 { b: mask32x8, c: mask32x8, ) -> mask32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_mask32x8(b); - let (c0, c1) = self.split_mask32x8(c); - self.combine_mask32x4( - self.select_mask32x4(a0, b0, c0), - self.select_mask32x4(a1, b1, c1), - ) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask32x8(self, a: mask32x8, b: mask32x8) -> mask32x8 { - let (a0, a1) = self.split_mask32x8(a); - let (b0, b1) = self.split_mask32x8(b); - self.combine_mask32x4(self.simd_eq_mask32x4(a0, b0), self.simd_eq_mask32x4(a1, b1)) + unsafe { _mm256_cmpeq_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_mask32x8(self, a: mask32x8, b: mask32x8) -> mask32x16 { @@ -2827,174 +2656,141 @@ impl Simd for Avx2 { (b0.simd_into(self), b1.simd_into(self)) } #[inline(always)] - fn splat_f64x4(self, a: f64) -> f64x4 { - let half = self.splat_f64x2(a); - self.combine_f64x2(half, half) + fn splat_f64x4(self, val: f64) -> f64x4 { + unsafe { _mm256_set1_pd(val).simd_into(self) } } #[inline(always)] fn abs_f64x4(self, a: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - self.combine_f64x2(self.abs_f64x2(a0), self.abs_f64x2(a1)) + unsafe { _mm256_andnot_pd(_mm256_set1_pd(-0.0), a.into()).simd_into(self) } } #[inline(always)] fn neg_f64x4(self, a: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - self.combine_f64x2(self.neg_f64x2(a0), self.neg_f64x2(a1)) + unsafe { _mm256_xor_pd(a.into(), _mm256_set1_pd(-0.0)).simd_into(self) } } #[inline(always)] fn sqrt_f64x4(self, a: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - self.combine_f64x2(self.sqrt_f64x2(a0), self.sqrt_f64x2(a1)) + unsafe { _mm256_sqrt_pd(a.into()).simd_into(self) } } #[inline(always)] fn add_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.add_f64x2(a0, b0), self.add_f64x2(a1, b1)) + unsafe { _mm256_add_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn sub_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.sub_f64x2(a0, b0), self.sub_f64x2(a1, b1)) + unsafe { _mm256_sub_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn mul_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.mul_f64x2(a0, b0), self.mul_f64x2(a1, b1)) + unsafe { _mm256_mul_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn div_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.div_f64x2(a0, b0), self.div_f64x2(a1, b1)) + unsafe { _mm256_div_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn copysign_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.copysign_f64x2(a0, b0), self.copysign_f64x2(a1, b1)) + unsafe { + let mask = _mm256_set1_pd(-0.0); + _mm256_or_pd( + _mm256_and_pd(mask, b.into()), + _mm256_andnot_pd(mask, a.into()), + ) + .simd_into(self) + } } #[inline(always)] fn simd_eq_f64x4(self, a: f64x4, b: f64x4) -> mask64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - 
self.combine_mask64x2(self.simd_eq_f64x2(a0, b0), self.simd_eq_f64x2(a1, b1)) + unsafe { _mm256_castpd_si256(_mm256_cmp_pd::<0i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_lt_f64x4(self, a: f64x4, b: f64x4) -> mask64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_mask64x2(self.simd_lt_f64x2(a0, b0), self.simd_lt_f64x2(a1, b1)) + unsafe { _mm256_castpd_si256(_mm256_cmp_pd::<17i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_le_f64x4(self, a: f64x4, b: f64x4) -> mask64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_mask64x2(self.simd_le_f64x2(a0, b0), self.simd_le_f64x2(a1, b1)) + unsafe { _mm256_castpd_si256(_mm256_cmp_pd::<18i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_ge_f64x4(self, a: f64x4, b: f64x4) -> mask64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_mask64x2(self.simd_ge_f64x2(a0, b0), self.simd_ge_f64x2(a1, b1)) + unsafe { _mm256_castpd_si256(_mm256_cmp_pd::<29i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn simd_gt_f64x4(self, a: f64x4, b: f64x4) -> mask64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_mask64x2(self.simd_gt_f64x2(a0, b0), self.simd_gt_f64x2(a1, b1)) + unsafe { _mm256_castpd_si256(_mm256_cmp_pd::<30i32>(a.into(), b.into())).simd_into(self) } } #[inline(always)] fn zip_low_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, _) = self.split_f64x4(a); - let (b0, _) = self.split_f64x4(b); - self.combine_f64x2(self.zip_low_f64x2(a0, b0), self.zip_high_f64x2(a0, b0)) + unsafe { + let lo = _mm256_unpacklo_pd(a.into(), b.into()); + let hi = _mm256_unpackhi_pd(a.into(), b.into()); + _mm256_permute2f128_pd::<0b0010_0000>(lo, hi).simd_into(self) + } } #[inline(always)] fn zip_high_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (_, a1) = self.split_f64x4(a); - let (_, b1) = self.split_f64x4(b); - self.combine_f64x2(self.zip_low_f64x2(a1, b1), self.zip_high_f64x2(a1, b1)) + unsafe { + let lo = _mm256_unpacklo_pd(a.into(), b.into()); + let hi = _mm256_unpackhi_pd(a.into(), b.into()); + _mm256_permute2f128_pd::<0b0011_0001>(lo, hi).simd_into(self) + } } #[inline(always)] fn unzip_low_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.unzip_low_f64x2(a0, a1), self.unzip_low_f64x2(b0, b1)) + unsafe { + let t1 = _mm256_permute4x64_pd::<0b11_01_10_00>(a.into()); + let t2 = _mm256_permute4x64_pd::<0b11_01_10_00>(b.into()); + _mm256_permute2f128_pd::<0b0010_0000>(t1, t2).simd_into(self) + } } #[inline(always)] fn unzip_high_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.unzip_high_f64x2(a0, a1), self.unzip_high_f64x2(b0, b1)) + unsafe { + let t1 = _mm256_permute4x64_pd::<0b11_01_10_00>(a.into()); + let t2 = _mm256_permute4x64_pd::<0b11_01_10_00>(b.into()); + _mm256_permute2f128_pd::<0b0011_0001>(t1, t2).simd_into(self) + } } #[inline(always)] fn max_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.max_f64x2(a0, b0), self.max_f64x2(a1, b1)) + unsafe { _mm256_max_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn max_precise_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); 
- let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2( - self.max_precise_f64x2(a0, b0), - self.max_precise_f64x2(a1, b1), - ) + unsafe { _mm256_max_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn min_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2(self.min_f64x2(a0, b0), self.min_f64x2(a1, b1)) + unsafe { _mm256_min_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn min_precise_f64x4(self, a: f64x4, b: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - self.combine_f64x2( - self.min_precise_f64x2(a0, b0), - self.min_precise_f64x2(a1, b1), - ) + unsafe { _mm256_min_pd(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn madd_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - let (c0, c1) = self.split_f64x4(c); - self.combine_f64x2(self.madd_f64x2(a0, b0, c0), self.madd_f64x2(a1, b1, c1)) + unsafe { _mm256_fmadd_pd(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] fn msub_f64x4(self, a: f64x4, b: f64x4, c: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - let (b0, b1) = self.split_f64x4(b); - let (c0, c1) = self.split_f64x4(c); - self.combine_f64x2(self.msub_f64x2(a0, b0, c0), self.msub_f64x2(a1, b1, c1)) + unsafe { _mm256_fmsub_pd(a.into(), b.into(), c.into()).simd_into(self) } } #[inline(always)] fn floor_f64x4(self, a: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - self.combine_f64x2(self.floor_f64x2(a0), self.floor_f64x2(a1)) + unsafe { _mm256_floor_pd(a.into()).simd_into(self) } } #[inline(always)] fn fract_f64x4(self, a: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - self.combine_f64x2(self.fract_f64x2(a0), self.fract_f64x2(a1)) + a - a.trunc() } #[inline(always)] fn trunc_f64x4(self, a: f64x4) -> f64x4 { - let (a0, a1) = self.split_f64x4(a); - self.combine_f64x2(self.trunc_f64x2(a0), self.trunc_f64x2(a1)) + unsafe { _mm256_round_pd(a.into(), _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC).simd_into(self) } } #[inline(always)] fn select_f64x4(self, a: mask64x4, b: f64x4, c: f64x4) -> f64x4 { - let (a0, a1) = self.split_mask64x4(a); - let (b0, b1) = self.split_f64x4(b); - let (c0, c1) = self.split_f64x4(c); - self.combine_f64x2(self.select_f64x2(a0, b0, c0), self.select_f64x2(a1, b1, c1)) + unsafe { + _mm256_blendv_pd(c.into(), b.into(), _mm256_castsi256_pd(a.into())).simd_into(self) + } } #[inline(always)] fn combine_f64x4(self, a: f64x4, b: f64x4) -> f64x8 { @@ -3013,39 +2809,30 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_f32_f64x4(self, a: f64x4) -> f32x8 { - let (a0, a1) = self.split_f64x4(a); - self.combine_f32x4( - self.reinterpret_f32_f64x2(a0), - self.reinterpret_f32_f64x2(a1), - ) + f32x8 { + val: bytemuck::cast(a.val), + simd: a.simd, + } } #[inline(always)] - fn splat_mask64x4(self, a: i64) -> mask64x4 { - let half = self.splat_mask64x2(a); - self.combine_mask64x2(half, half) + fn splat_mask64x4(self, val: i64) -> mask64x4 { + unsafe { _mm256_set1_epi64x(val).simd_into(self) } } #[inline(always)] fn not_mask64x4(self, a: mask64x4) -> mask64x4 { - let (a0, a1) = self.split_mask64x4(a); - self.combine_mask64x2(self.not_mask64x2(a0), self.not_mask64x2(a1)) + a ^ !0 } #[inline(always)] fn and_mask64x4(self, a: mask64x4, b: mask64x4) -> mask64x4 { - let (a0, a1) = self.split_mask64x4(a); - let (b0, b1) = self.split_mask64x4(b); - 
self.combine_mask64x2(self.and_mask64x2(a0, b0), self.and_mask64x2(a1, b1)) + unsafe { _mm256_and_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn or_mask64x4(self, a: mask64x4, b: mask64x4) -> mask64x4 { - let (a0, a1) = self.split_mask64x4(a); - let (b0, b1) = self.split_mask64x4(b); - self.combine_mask64x2(self.or_mask64x2(a0, b0), self.or_mask64x2(a1, b1)) + unsafe { _mm256_or_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn xor_mask64x4(self, a: mask64x4, b: mask64x4) -> mask64x4 { - let (a0, a1) = self.split_mask64x4(a); - let (b0, b1) = self.split_mask64x4(b); - self.combine_mask64x2(self.xor_mask64x2(a0, b0), self.xor_mask64x2(a1, b1)) + unsafe { _mm256_xor_si256(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn select_mask64x4( @@ -3054,19 +2841,11 @@ impl Simd for Avx2 { b: mask64x4, c: mask64x4, ) -> mask64x4 { - let (a0, a1) = self.split_mask64x4(a); - let (b0, b1) = self.split_mask64x4(b); - let (c0, c1) = self.split_mask64x4(c); - self.combine_mask64x2( - self.select_mask64x2(a0, b0, c0), - self.select_mask64x2(a1, b1, c1), - ) + unsafe { _mm256_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask64x4(self, a: mask64x4, b: mask64x4) -> mask64x4 { - let (a0, a1) = self.split_mask64x4(a); - let (b0, b1) = self.split_mask64x4(b); - self.combine_mask64x2(self.simd_eq_mask64x2(a0, b0), self.simd_eq_mask64x2(a1, b1)) + unsafe { _mm256_cmpeq_epi64(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn combine_mask64x4(self, a: mask64x4, b: mask64x4) -> mask64x8 { diff --git a/fearless_simd/src/generated/sse4_2.rs b/fearless_simd/src/generated/sse4_2.rs index 6496ecb1..c9250c89 100644 --- a/fearless_simd/src/generated/sse4_2.rs +++ b/fearless_simd/src/generated/sse4_2.rs @@ -181,10 +181,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn select_f32x4(self, a: mask32x4, b: f32x4, c: f32x4) -> f32x4 { - unsafe { - let mask = _mm_castsi128_ps(a.into()); - _mm_or_ps(_mm_and_ps(mask, b.into()), _mm_andnot_ps(mask, c.into())).simd_into(self) - } + unsafe { _mm_blendv_ps(c.into(), b.into(), _mm_castsi128_ps(a.into())).simd_into(self) } } #[inline(always)] fn combine_f32x4(self, a: f32x4, b: f32x4) -> f32x8 { @@ -268,8 +265,8 @@ impl Simd for Sse4_2 { unsafe { let val = a.into(); let shift_count = _mm_cvtsi32_si128(shift as i32); - let lo_16 = _mm_unpacklo_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); - let hi_16 = _mm_unpackhi_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); + let lo_16 = _mm_unpacklo_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); + let hi_16 = _mm_unpackhi_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); let lo_shifted = _mm_sra_epi16(lo_16, shift_count); let hi_shifted = _mm_sra_epi16(hi_16, shift_count); _mm_packs_epi16(lo_shifted, hi_shifted).simd_into(self) @@ -284,8 +281,8 @@ impl Simd for Sse4_2 { unsafe { let val = a.into(); let shift_count = _mm_cvtsi32_si128(shift as i32); - let lo_16 = _mm_unpacklo_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); - let hi_16 = _mm_unpackhi_epi8(val, _mm_cmplt_epi8(val, _mm_setzero_si128())); + let lo_16 = _mm_unpacklo_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); + let hi_16 = _mm_unpackhi_epi8(val, _mm_cmpgt_epi8(_mm_setzero_si128(), val)); let lo_shifted = _mm_sll_epi16(lo_16, shift_count); let hi_shifted = _mm_sll_epi16(hi_16, shift_count); _mm_packs_epi16(lo_shifted, hi_shifted).simd_into(self) @@ -297,7 +294,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn simd_lt_i8x16(self, a: i8x16, b: i8x16) -> 
mask8x16 { - unsafe { _mm_cmplt_epi8(a.into(), b.into()).simd_into(self) } + unsafe { _mm_cmpgt_epi8(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i8x16(self, a: i8x16, b: i8x16) -> mask8x16 { @@ -322,7 +319,7 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_low_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { unsafe { - let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -331,21 +328,15 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_high_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { unsafe { - let mask = _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_i8x16(self, a: mask8x16, b: i8x16, c: i8x16) -> i8x16 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { @@ -442,12 +433,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn simd_eq_u8x16(self, a: u8x16, b: u8x16) -> mask8x16 { - unsafe { - let sign_bit = _mm_set1_epi8(0x80u8 as _); - let a_signed = _mm_xor_si128(a.into(), sign_bit); - let b_signed = _mm_xor_si128(b.into(), sign_bit); - _mm_cmpgt_epi8(a_signed, b_signed).simd_into(self) - } + unsafe { _mm_cmpeq_epi8(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u8x16(self, a: u8x16, b: u8x16) -> mask8x16 { @@ -486,7 +472,7 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_low_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { unsafe { - let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -495,21 +481,15 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_high_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { unsafe { - let mask = _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15); + let mask = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_u8x16(self, a: mask8x16, b: u8x16, c: u8x16) -> u8x16 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { @@ -569,13 +549,7 @@ impl Simd for Sse4_2 { b: mask8x16, c: mask8x16, ) -> mask8x16 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask8x16(self, a: mask8x16, b: 
mask8x16) -> mask8x16 { @@ -638,7 +612,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn simd_lt_i16x8(self, a: i16x8, b: i16x8) -> mask16x8 { - unsafe { _mm_cmplt_epi16(a.into(), b.into()).simd_into(self) } + unsafe { _mm_cmpgt_epi16(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i16x8(self, a: i16x8, b: i16x8) -> mask16x8 { @@ -663,7 +637,7 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_low_i16x8(self, a: i16x8, b: i16x8) -> i16x8 { unsafe { - let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -672,21 +646,15 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_high_i16x8(self, a: i16x8, b: i16x8) -> i16x8 { unsafe { - let mask = _mm_setr_epi8(2, 3, 6, 7, 10, 11, 14, 15, 2, 3, 6, 7, 10, 11, 14, 15); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_i16x8(self, a: mask16x8, b: i16x8, c: i16x8) -> i16x8 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i16x8(self, a: i16x8, b: i16x8) -> i16x8 { @@ -767,12 +735,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn simd_eq_u16x8(self, a: u16x8, b: u16x8) -> mask16x8 { - unsafe { - let sign_bit = _mm_set1_epi16(0x8000u16 as _); - let a_signed = _mm_xor_si128(a.into(), sign_bit); - let b_signed = _mm_xor_si128(b.into(), sign_bit); - _mm_cmpgt_epi16(a_signed, b_signed).simd_into(self) - } + unsafe { _mm_cmpeq_epi16(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u16x8(self, a: u16x8, b: u16x8) -> mask16x8 { @@ -811,7 +774,7 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_low_u16x8(self, a: u16x8, b: u16x8) -> u16x8 { unsafe { - let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); _mm_unpacklo_epi64(t1, t2).simd_into(self) @@ -820,21 +783,15 @@ impl Simd for Sse4_2 { #[inline(always)] fn unzip_high_u16x8(self, a: u16x8, b: u16x8) -> u16x8 { unsafe { - let mask = _mm_setr_epi8(2, 3, 6, 7, 10, 11, 14, 15, 2, 3, 6, 7, 10, 11, 14, 15); + let mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15); let t1 = _mm_shuffle_epi8(a.into(), mask); let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) + _mm_unpackhi_epi64(t1, t2).simd_into(self) } } #[inline(always)] fn select_u16x8(self, a: mask16x8, b: u16x8, c: u16x8) -> u16x8 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u16x8(self, a: u16x8, b: u16x8) -> u16x8 { @@ -892,13 +849,7 @@ impl Simd for Sse4_2 { b: mask16x8, c: mask16x8, ) -> mask16x8 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - 
} + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask16x8(self, a: mask16x8, b: mask16x8) -> mask16x8 { @@ -961,7 +912,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn simd_lt_i32x4(self, a: i32x4, b: i32x4) -> mask32x4 { - unsafe { _mm_cmplt_epi32(a.into(), b.into()).simd_into(self) } + unsafe { _mm_cmpgt_epi32(b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_le_i32x4(self, a: i32x4, b: i32x4) -> mask32x4 { @@ -1001,13 +952,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn select_i32x4(self, a: mask32x4, b: i32x4, c: i32x4) -> i32x4 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_i32x4(self, a: i32x4, b: i32x4) -> i32x4 { @@ -1092,12 +1037,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn simd_eq_u32x4(self, a: u32x4, b: u32x4) -> mask32x4 { - unsafe { - let sign_bit = _mm_set1_epi32(0x80000000u32 as _); - let a_signed = _mm_xor_si128(a.into(), sign_bit); - let b_signed = _mm_xor_si128(b.into(), sign_bit); - _mm_cmpgt_epi32(a_signed, b_signed).simd_into(self) - } + unsafe { _mm_cmpeq_epi32(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn simd_lt_u32x4(self, a: u32x4, b: u32x4) -> mask32x4 { @@ -1151,13 +1091,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn select_u32x4(self, a: mask32x4, b: u32x4, c: u32x4) -> u32x4 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn min_u32x4(self, a: u32x4, b: u32x4) -> u32x4 { @@ -1212,13 +1146,7 @@ impl Simd for Sse4_2 { b: mask32x4, c: mask32x4, ) -> mask32x4 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask32x4(self, a: mask32x4, b: mask32x4) -> mask32x4 { @@ -1344,10 +1272,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn select_f64x2(self, a: mask64x2, b: f64x2, c: f64x2) -> f64x2 { - unsafe { - let mask = _mm_castsi128_pd(a.into()); - _mm_or_pd(_mm_and_pd(mask, b.into()), _mm_andnot_pd(mask, c.into())).simd_into(self) - } + unsafe { _mm_blendv_pd(c.into(), b.into(), _mm_castsi128_pd(a.into())).simd_into(self) } } #[inline(always)] fn combine_f64x2(self, a: f64x2, b: f64x2) -> f64x4 { @@ -1390,13 +1315,7 @@ impl Simd for Sse4_2 { b: mask64x2, c: mask64x2, ) -> mask64x2 { - unsafe { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()), - ) - .simd_into(self) - } + unsafe { _mm_blendv_epi8(c.into(), b.into(), a.into()).simd_into(self) } } #[inline(always)] fn simd_eq_mask64x2(self, a: mask64x2, b: mask64x2) -> mask64x2 { diff --git a/fearless_simd_gen/src/arch/x86_common.rs b/fearless_simd_gen/src/arch/x86_common.rs index 41450470..615cc77d 100644 --- a/fearless_simd_gen/src/arch/x86_common.rs +++ b/fearless_simd_gen/src/arch/x86_common.rs @@ -3,7 +3,7 @@ use crate::types::{ScalarType, VecType}; use crate::x86_common::{ - intrinsic_ident, op_suffix, set0_intrinsic, set1_intrinsic, simple_intrinsic, + coarse_type, intrinsic_ident, op_suffix, set1_intrinsic, simple_intrinsic, }; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -29,6 +29,7 @@ pub(crate) 
fn translate_op(op: &str) -> Option<&'static str> { "min" => "min", "max_precise" => "max", "min_precise" => "min", + "select" => "blendv", _ => return None, }) } @@ -50,7 +51,8 @@ pub(crate) fn expr(op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream let sign_aware = matches!(op, "max" | "min"); let suffix = match op_name { - "and" | "or" | "xor" => "si128", + "and" | "or" | "xor" => coarse_type(*ty), + "blendv" if ty.scalar != ScalarType::Float => "epi8", _ => op_suffix(ty.scalar, ty.scalar_bits, sign_aware), }; let intrinsic = intrinsic_ident(op_name, suffix, ty.n_bits()); @@ -72,7 +74,7 @@ pub(crate) fn expr(op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream } } ScalarType::Int => { - let set0 = set0_intrinsic(*ty); + let set0 = intrinsic_ident("setzero", coarse_type(*ty), ty.n_bits()); let sub = simple_intrinsic("sub", ty.scalar, ty.scalar_bits, ty.n_bits()); let arg = &args[0]; quote! { diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index 9d501fd9..a205b382 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -7,8 +7,8 @@ use crate::arch::sse4_2::Sse4_2; use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; use crate::mk_sse4_2; use crate::ops::{OpSig, TyFlavor, ops_for_type}; -use crate::types::{SIMD_TYPES, VecType, type_imports}; -use crate::x86_common::simple_intrinsic; +use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; +use crate::x86_common::{cast_ident, simple_intrinsic}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -82,20 +82,17 @@ fn mk_simd_impl() -> TokenStream { let mut methods = vec![]; for vec_ty in SIMD_TYPES { for (method, sig) in ops_for_type(vec_ty, true) { - // TODO: Right now, we are basically adding the same methods as for SSE4.2 (except for - // FMA). In the future, we'll obviously want to use AVX2 intrinsics for 256 bit. 
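(Editor's aside, not part of the patch: a minimal sketch of why the generated integer `select` can map to the byte-granular blend regardless of lane width. The function name is invented for illustration and SSE4.1 is assumed; because the masks fearless_simd produces are all-ones or all-zeros per lane, blending byte-by-byte gives the same result as a per-lane select.)

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn select_i32x4(mask: [i32; 4], if_true: [i32; 4], if_false: [i32; 4]) -> [i32; 4] {
    use core::arch::x86_64::*;
    let m = _mm_loadu_si128(mask.as_ptr() as *const __m128i);
    let t = _mm_loadu_si128(if_true.as_ptr() as *const __m128i);
    let f = _mm_loadu_si128(if_false.as_ptr() as *const __m128i);
    // _mm_blendv_epi8 takes each *byte* from its second argument where the mask
    // byte's top bit is set; note that the "mask clear" value comes first.
    let r = _mm_blendv_epi8(f, t, m);
    let mut out = [0_i32; 4];
    _mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, r);
    out
}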
-            let b1 = (vec_ty.n_bits() > 128 && !matches!(method, "split" | "narrow"))
-                || vec_ty.n_bits() > 256;
+            let too_wide = vec_ty.n_bits() > 256;
-            let b2 = !matches!(method, "load_interleaved_128")
-                && !matches!(method, "store_interleaved_128");
+            let acceptable_wide_op = matches!(method, "load_interleaved_128")
+                || matches!(method, "store_interleaved_128");
-            if b1 && b2 {
+            if too_wide && !acceptable_wide_op {
                 methods.push(generic_op(method, sig, vec_ty));
                 continue;
             }
-            let method = make_method(method, sig, vec_ty, Sse4_2, 128);
+            let method = make_method(method, sig, vec_ty, Sse4_2, vec_ty.n_bits());
             methods.push(method);
         }
@@ -191,12 +188,14 @@ fn make_method(
     match sig {
         OpSig::Splat => mk_sse4_2::handle_splat(method_sig, vec_ty, scalar_bits, ty_bits),
-        OpSig::Compare => {
-            mk_sse4_2::handle_compare(method_sig, method, vec_ty, scalar_bits, ty_bits, arch)
-        }
+        OpSig::Compare => handle_compare(method_sig, method, vec_ty, scalar_bits, ty_bits, arch),
         OpSig::Unary => mk_sse4_2::handle_unary(method_sig, method, vec_ty, arch),
         OpSig::WidenNarrow(t) => {
-            mk_sse4_2::handle_widen_narrow(method_sig, method, vec_ty, scalar_bits, ty_bits, t)
+            if vec_ty.n_bits() > 128 && method == "widen" {
+                generic_op(method, sig, vec_ty)
+            } else {
+                mk_sse4_2::handle_widen_narrow(method_sig, method, vec_ty, scalar_bits, ty_bits, t)
+            }
         }
         OpSig::Binary => mk_sse4_2::handle_binary(method_sig, method, vec_ty, arch),
         OpSig::Shift => mk_sse4_2::handle_shift(method_sig, method, vec_ty, scalar_bits, ty_bits),
@@ -242,3 +241,37 @@ fn make_method(
         }
     }
 }
+
+pub(crate) fn handle_compare(
+    method_sig: TokenStream,
+    method: &str,
+    vec_ty: &VecType,
+    scalar_bits: usize,
+    ty_bits: usize,
+    arch: impl Arch,
+) -> TokenStream {
+    if vec_ty.scalar == ScalarType::Float {
+        // For AVX2 and up, Intel gives us a generic comparison intrinsic that takes a predicate immediate. There are 32,
+        // but only a handful are useful here; the rest are negated/unordered variants or signaling predicates that raise the invalid-operation exception on NaN operands.
+        //
+        // https://www.felixcloutier.com/x86/cmppd#tbl-3-1
+        let order_predicate = match method {
+            "simd_eq" => 0x00,
+            "simd_lt" => 0x11,
+            "simd_le" => 0x12,
+            "simd_ge" => 0x1D,
+            "simd_gt" => 0x1E,
+            _ => unreachable!(),
+        };
+        let intrinsic = simple_intrinsic("cmp", vec_ty.scalar, scalar_bits, ty_bits);
+        let cast = cast_ident(ScalarType::Float, ScalarType::Mask, scalar_bits, ty_bits);
+
+        quote!
{ + #method_sig { + unsafe { #cast(#intrinsic::<#order_predicate>(a.into(), b.into())).simd_into(self) } + } + } + } else { + mk_sse4_2::handle_compare(method_sig, method, vec_ty, scalar_bits, ty_bits, arch) + } +} diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index 3baf0b55..ac2aed6f 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -7,8 +7,9 @@ use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; use crate::ops::{OpSig, TyFlavor, ops_for_type, reinterpret_ty, valid_reinterpret}; use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; use crate::x86_common::{ - cvt_intrinsic, extend_intrinsic, op_suffix, pack_intrinsic, set1_intrinsic, simple_intrinsic, - simple_sign_unaware_intrinsic, unpack_intrinsic, + cast_ident, coarse_type, cvt_intrinsic, extend_intrinsic, intrinsic_ident, op_suffix, + pack_intrinsic, set1_intrinsic, simple_intrinsic, simple_sign_unaware_intrinsic, + unpack_intrinsic, }; use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote}; @@ -250,56 +251,78 @@ pub(crate) fn handle_compare( ) -> TokenStream { let args = [quote! { a.into() }, quote! { b.into() }]; - let mut expr = if vec_ty.scalar != ScalarType::Float { - if matches!(method, "simd_le" | "simd_ge") { - let max_min = match method { - "simd_le" => "min", - "simd_ge" => "max", - _ => unreachable!(), - }; - - let eq_intrinsic = - simple_sign_unaware_intrinsic("cmpeq", vec_ty.scalar, vec_ty.scalar_bits, ty_bits); - - let max_min_expr = arch.expr(max_min, vec_ty, &args); - quote! { #eq_intrinsic(#max_min_expr, a.into()) } - } else if vec_ty.scalar == ScalarType::Unsigned { - // SSE4.2 only has signed GT/LT, but not unsigned. - let set = set1_intrinsic(vec_ty.scalar, vec_ty.scalar_bits, ty_bits); - let sign = match vec_ty.scalar_bits { - 8 => quote! { 0x80u8 }, - 16 => quote! { 0x8000u16 }, - 32 => quote! { 0x80000000u32 }, - _ => unimplemented!(), - }; - let gt = - simple_sign_unaware_intrinsic("cmpgt", vec_ty.scalar, vec_ty.scalar_bits, ty_bits); - let args = if method == "simd_lt" { - quote! { b_signed, a_signed } - } else { - quote! { a_signed, b_signed } - }; - - quote! { - let sign_bit = #set(#sign as _); - let a_signed = _mm_xor_si128(a.into(), sign_bit); - let b_signed = _mm_xor_si128(b.into(), sign_bit); + let expr = if vec_ty.scalar != ScalarType::Float { + match method { + "simd_le" | "simd_ge" => { + let max_min = match method { + "simd_le" => "min", + "simd_ge" => "max", + _ => unreachable!(), + }; - #gt(#args) + // TODO: in some places, we use vec_ty.scalar bits, and in other places, we use the scalar_bits argument. + // AFAIK, these never differ. + let eq_intrinsic = simple_sign_unaware_intrinsic( + "cmpeq", + vec_ty.scalar, + vec_ty.scalar_bits, + ty_bits, + ); + + let max_min_expr = arch.expr(max_min, vec_ty, &args); + quote! { #eq_intrinsic(#max_min_expr, a.into()) } } - } else { - arch.expr(method, vec_ty, &args) + "simd_lt" | "simd_gt" => { + let gt = simple_sign_unaware_intrinsic( + "cmpgt", + vec_ty.scalar, + vec_ty.scalar_bits, + ty_bits, + ); + + if vec_ty.scalar == ScalarType::Unsigned { + // SSE4.2 only has signed GT/LT, but not unsigned. + let set = set1_intrinsic(vec_ty.scalar, vec_ty.scalar_bits, ty_bits); + let sign = match vec_ty.scalar_bits { + 8 => quote! { 0x80u8 }, + 16 => quote! { 0x8000u16 }, + 32 => quote! 
{ 0x80000000u32 }, + _ => unimplemented!(), + }; + let xor_op = intrinsic_ident("xor", coarse_type(*vec_ty), ty_bits); + let args = if method == "simd_lt" { + quote! { b_signed, a_signed } + } else { + quote! { a_signed, b_signed } + }; + + quote! { + let sign_bit = #set(#sign as _); + let a_signed = #xor_op(a.into(), sign_bit); + let b_signed = #xor_op(b.into(), sign_bit); + + #gt(#args) + } + } else { + let args = if method == "simd_lt" { + quote! { b.into(), a.into() } + } else { + quote! { a.into(), b.into() } + }; + quote! { + #gt(#args) + } + } + } + "simd_eq" => arch.expr(method, vec_ty, &args), + _ => unreachable!(), } } else { - arch.expr(method, vec_ty, &args) + let expr = arch.expr(method, vec_ty, &args); + let ident = cast_ident(ScalarType::Float, ScalarType::Mask, scalar_bits, ty_bits); + quote! { #ident(#expr) } }; - if vec_ty.scalar == ScalarType::Float { - let suffix = op_suffix(vec_ty.scalar, scalar_bits, false); - let ident = format_ident!("_mm_cast{suffix}_si128"); - expr = quote! { #ident(#expr) } - } - quote! { #method_sig { unsafe { #expr.simd_into(self) } @@ -374,11 +397,11 @@ pub(crate) fn handle_widen_narrow( } } "narrow" => { - let mask = set1_intrinsic(vec_ty.scalar, scalar_bits, ty_bits); + let mask = set1_intrinsic(vec_ty.scalar, scalar_bits, t.n_bits()); let pack = pack_intrinsic( scalar_bits, matches!(vec_ty.scalar, ScalarType::Int), - ty_bits, + t.n_bits(), ); let split = format_ident!("split_{}", vec_ty.rust_name()); quote! { @@ -437,7 +460,7 @@ pub(crate) fn handle_shift( _ => unreachable!(), }; let suffix = op_suffix(vec_ty.scalar, scalar_bits.max(16), false); - let shift_intrinsic = format_ident!("_mm_{op}_{suffix}"); + let shift_intrinsic = intrinsic_ident(op, suffix, ty_bits); if scalar_bits == 8 { // SSE doesn't have shifting for 8-bit, so we first convert into @@ -446,13 +469,17 @@ pub(crate) fn handle_shift( let unpack_hi = unpack_intrinsic(ScalarType::Int, 8, false, ty_bits); let unpack_lo = unpack_intrinsic(ScalarType::Int, 8, true, ty_bits); + let set0 = intrinsic_ident("setzero", coarse_type(*vec_ty), ty_bits); let extend_expr = |expr| match vec_ty.scalar { ScalarType::Unsigned => quote! { - #expr(val, _mm_setzero_si128()) - }, - ScalarType::Int => quote! { - #expr(val, _mm_cmplt_epi8(val, _mm_setzero_si128())) + #expr(val, #set0()) }, + ScalarType::Int => { + let cmp_intrinsic = intrinsic_ident("cmpgt", "epi8", ty_bits); + quote! { + #expr(val, #cmp_intrinsic(#set0(), val)) + } + } _ => unimplemented!(), }; @@ -528,36 +555,28 @@ pub(crate) fn handle_select( vec_ty: &VecType, scalar_bits: usize, ) -> TokenStream { - let expr = if vec_ty.scalar == ScalarType::Float { - let suffix = op_suffix(vec_ty.scalar, scalar_bits, false); - let (i1, i2, i3, i4) = ( - format_ident!("_mm_castsi128_{suffix}"), - format_ident!("_mm_or_{suffix}"), - format_ident!("_mm_and_{suffix}"), - format_ident!("_mm_andnot_{suffix}"), - ); - quote! { - let mask = #i1(a.into()); - - #i2( - #i3(mask, b.into()), - #i4(mask, c.into()) - ) - } - } else { - quote! { - _mm_or_si128( - _mm_and_si128(a.into(), b.into()), - _mm_andnot_si128(a.into(), c.into()) - ) - } - }; + // Our select ops' argument order is mask, a, b; Intel's intrinsics are b, a, mask + let args = [ + quote! { c.into() }, + quote! { b.into() }, + match vec_ty.scalar { + ScalarType::Float => { + let ident = cast_ident( + ScalarType::Mask, + ScalarType::Float, + scalar_bits, + vec_ty.n_bits(), + ); + quote! { #ident(a.into()) } + } + _ => quote! 
{ a.into() }, + }, + ]; + let expr = Sse4_2.expr("select", vec_ty, &args); quote! { #method_sig { - unsafe { - #expr.simd_into(self) - } + unsafe { #expr.simd_into(self) } } } } @@ -568,14 +587,50 @@ pub(crate) fn handle_zip( scalar_bits: usize, zip1: bool, ) -> TokenStream { - let op = if zip1 { "lo" } else { "hi" }; + let expr = match vec_ty.n_bits() { + 128 => { + let op = if zip1 { "unpacklo" } else { "unpackhi" }; + + let suffix = op_suffix(vec_ty.scalar, scalar_bits, false); + let unpack_intrinsic = intrinsic_ident(op, suffix, vec_ty.n_bits()); + quote! { + unsafe { #unpack_intrinsic(a.into(), b.into()).simd_into(self) } + } + } + 256 => { + let suffix = op_suffix(vec_ty.scalar, scalar_bits, false); + let lo = intrinsic_ident("unpacklo", suffix, vec_ty.n_bits()); + let hi = intrinsic_ident("unpackhi", suffix, vec_ty.n_bits()); + let shuffle_immediate = if zip1 { + quote! { 0b0010_0000 } + } else { + quote! { 0b0011_0001 } + }; - let suffix = op_suffix(vec_ty.scalar, scalar_bits, false); - let intrinsic = format_ident!("_mm_unpack{op}_{suffix}"); + let shuffle = intrinsic_ident( + match vec_ty.scalar { + ScalarType::Float => "permute2f128", + _ => "permute2x128", + }, + coarse_type(*vec_ty), + 256, + ); + + quote! { + unsafe { + let lo = #lo(a.into(), b.into()); + let hi = #hi(a.into(), b.into()); + + #shuffle::<#shuffle_immediate>(lo, hi).simd_into(self) + } + } + } + _ => unreachable!(), + }; quote! { #method_sig { - unsafe { #intrinsic(a.into(), b.into()).simd_into(self) } + #expr } } } @@ -586,63 +641,149 @@ pub(crate) fn handle_unzip( scalar_bits: usize, select_even: bool, ) -> TokenStream { - let expr = if vec_ty.scalar == ScalarType::Float { - let suffix = op_suffix(vec_ty.scalar, scalar_bits, false); - let intrinsic = format_ident!("_mm_shuffle_{suffix}"); - - let mask = match (vec_ty.scalar_bits, select_even) { - (32, true) => quote! { 0b10_00_10_00 }, - (32, false) => quote! { 0b11_01_11_01 }, - (64, true) => quote! { 0b00 }, - (64, false) => quote! { 0b11 }, - _ => unimplemented!(), - }; + let expr = match (vec_ty.scalar, vec_ty.n_bits(), scalar_bits) { + (ScalarType::Float, 128, _) => { + // 128-bit shuffle of floats or doubles; there are built-in SSE intrinsics for this + let suffix = op_suffix(vec_ty.scalar, scalar_bits, false); + let intrinsic = intrinsic_ident("shuffle", suffix, vec_ty.n_bits()); + + let mask = match (vec_ty.scalar_bits, select_even) { + (32, true) => quote! { 0b10_00_10_00 }, + (32, false) => quote! { 0b11_01_11_01 }, + (64, true) => quote! { 0b00 }, + (64, false) => quote! { 0b11 }, + _ => unimplemented!(), + }; - quote! { unsafe { #intrinsic::<#mask>(a.into(), b.into()).simd_into(self) } } - } else { - match vec_ty.scalar_bits { - 32 => { - let op = if select_even { "lo" } else { "hi" }; + quote! { unsafe { #intrinsic::<#mask>(a.into(), b.into()).simd_into(self) } } + } + (ScalarType::Int | ScalarType::Mask | ScalarType::Unsigned, 128, 32) => { + // 128-bit shuffle of 32-bit integers; unlike with floats, there is no single shuffle instruction that + // combines two vectors + let op = if select_even { "unpacklo" } else { "unpackhi" }; + let intrinsic = intrinsic_ident(op, "epi64", vec_ty.n_bits()); + + quote! 
{ + unsafe { + let t1 = _mm_shuffle_epi32::<0b11_01_10_00>(a.into()); + let t2 = _mm_shuffle_epi32::<0b11_01_10_00>(b.into()); + #intrinsic(t1, t2).simd_into(self) + } + } + } + (ScalarType::Int | ScalarType::Mask | ScalarType::Unsigned, 128, 16 | 8) => { + // Separate out the even-indexed and odd-indexed elements + let mask = match scalar_bits { + 8 => { + quote! { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 } + } + 16 => { + quote! { 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 } + } + _ => unreachable!(), + }; + let mask_reg = match vec_ty.n_bits() { + 128 => quote! { _mm_setr_epi8(#mask) }, + 256 => quote! { _mm256_setr_epi8(#mask, #mask) }, + _ => unreachable!(), + }; + let shuffle_epi8 = intrinsic_ident("shuffle", "epi8", vec_ty.n_bits()); - let intrinsic = format_ident!("_mm_unpack{op}_epi64"); + // Select either the low or high half of each one + let op = if select_even { "unpacklo" } else { "unpackhi" }; + let unpack_epi64 = intrinsic_ident(op, "epi64", vec_ty.n_bits()); - quote! { - unsafe { - let t1 = _mm_shuffle_epi32::<0b11_01_10_00>(a.into()); - let t2 = _mm_shuffle_epi32::<0b11_01_10_00>(b.into()); - #intrinsic(t1, t2).simd_into(self) - } + quote! { + unsafe { + let mask = #mask_reg; + + let t1 = #shuffle_epi8(a.into(), mask); + let t2 = #shuffle_epi8(b.into(), mask); + #unpack_epi64(t1, t2).simd_into(self) } } - 16 | 8 => { - let mask = match (scalar_bits, select_even) { - (8, true) => { - quote! { 0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14 } - } - (8, false) => { - quote! { 1, 3, 5, 7, 9, 11, 13, 15, 1, 3, 5, 7, 9, 11, 13, 15 } - } - (16, true) => { - quote! { 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13 } - } - (16, false) => { - quote! { 2, 3, 6, 7, 10, 11, 14, 15, 2, 3, 6, 7, 10, 11, 14, 15 } - } - _ => unreachable!(), - }; + } + (_, 256, 64 | 32) => { + // First we perform a lane-crossing shuffle to move the even-indexed elements of each input to the lower + // half, and the odd-indexed ones to the upper half. + // e.g. [0, 1, 2, 3, 4, 5, 6, 7] becomes [0, 2, 4, 6, 1, 3, 5, 7]). + let low_shuffle_kind = match scalar_bits { + 32 => "permutevar8x32", + 64 => "permute4x64", + _ => unreachable!(), + }; + let low_shuffle_suffix = op_suffix(vec_ty.scalar, scalar_bits, false); + let low_shuffle_intrinsic = intrinsic_ident(low_shuffle_kind, low_shuffle_suffix, 256); + let low_shuffle = |input_name: TokenStream| match scalar_bits { + 32 => { + quote! { #low_shuffle_intrinsic(#input_name, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)) } + } + 64 => quote! { #low_shuffle_intrinsic::<0b11_01_10_00>(#input_name) }, + _ => unreachable!(), + }; + let shuf_t1 = low_shuffle(quote! { a.into() }); + let shuf_t2 = low_shuffle(quote! { b.into() }); + + // Then we combine the lower or upper halves. + let high_shuffle = intrinsic_ident( + match vec_ty.scalar { + ScalarType::Float => "permute2f128", + _ => "permute2x128", + }, + coarse_type(*vec_ty), + 256, + ); + let high_shuffle_immediate = if select_even { + quote! { 0b0010_0000 } + } else { + quote! { 0b0011_0001 } + }; - quote! { - unsafe { - let mask = _mm_setr_epi8(#mask); + quote! { + unsafe { + let t1 = #shuf_t1; + let t2 = #shuf_t2; - let t1 = _mm_shuffle_epi8(a.into(), mask); - let t2 = _mm_shuffle_epi8(b.into(), mask); - _mm_unpacklo_epi64(t1, t2).simd_into(self) - } + #high_shuffle::<#high_shuffle_immediate>(t1, t2).simd_into(self) + } + } + } + (_, 256, 16 | 8) => { + // Separate out the even-indexed and odd-indexed elements within each 128-bit lane + let mask = match scalar_bits { + 8 => { + quote! 
{ 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 } + } + 16 => { + quote! { 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 } + } + _ => unreachable!(), + }; + + // We then permute the even-indexed and odd-indexed blocks across lanes, and finally do a 2x128 permute to + // select either the even- or odd-indexed elements + let high_shuffle_immediate = if select_even { + quote! { 0b0010_0000 } + } else { + quote! { 0b0011_0001 } + }; + + quote! { + unsafe { + let mask = _mm256_setr_epi8(#mask, #mask); + let a_shuffled = _mm256_shuffle_epi8(a.into(), mask); + let b_shuffled = _mm256_shuffle_epi8(b.into(), mask); + + let packed = _mm256_permute2x128_si256::<#high_shuffle_immediate>( + _mm256_permute4x64_epi64::<0b11_01_10_00>(a_shuffled), + _mm256_permute4x64_epi64::<0b11_01_10_00>(b_shuffled) + ); + + packed.simd_into(self) } } - _ => quote! { todo!() }, } + _ => unimplemented!(), }; quote! { diff --git a/fearless_simd_gen/src/x86_common.rs b/fearless_simd_gen/src/x86_common.rs index 2673dcf0..ab38097c 100644 --- a/fearless_simd_gen/src/x86_common.rs +++ b/fearless_simd_gen/src/x86_common.rs @@ -26,16 +26,15 @@ pub(crate) fn op_suffix(mut ty: ScalarType, bits: usize, sign_aware: bool) -> &' } } -pub(crate) fn set0_intrinsic(vec_ty: VecType) -> Ident { +/// Intrinsic name for the "int, float, or double" type (not as fine-grained as [`op_suffix`]). +pub(crate) fn coarse_type(vec_ty: VecType) -> &'static str { use ScalarType::*; - let suffix = match (vec_ty.scalar, vec_ty.n_bits()) { + match (vec_ty.scalar, vec_ty.n_bits()) { (Int | Unsigned | Mask, 128) => "si128", (Int | Unsigned | Mask, 256) => "si256", (Int | Unsigned | Mask, 512) => "si512", _ => op_suffix(vec_ty.scalar, vec_ty.scalar_bits, false), - }; - - intrinsic_ident("setzero", suffix, vec_ty.n_bits()) + } } pub(crate) fn set1_intrinsic(ty: ScalarType, bits: usize, ty_bits: usize) -> Ident { @@ -117,3 +116,29 @@ pub(crate) fn intrinsic_ident(name: &str, suffix: &str, ty_bits: usize) -> Ident format_ident!("_mm{prefix}_{name}_{suffix}") } + +pub(crate) fn cast_ident( + src_scalar_ty: ScalarType, + dst_scalar_ty: ScalarType, + scalar_bits: usize, + ty_bits: usize, +) -> Ident { + let prefix = match ty_bits { + 128 => "", + 256 => "256", + 512 => "512", + _ => unreachable!(), + }; + let src_name = coarse_type(VecType::new( + src_scalar_ty, + scalar_bits, + ty_bits / scalar_bits, + )); + let dst_name = coarse_type(VecType::new( + dst_scalar_ty, + scalar_bits, + ty_bits / scalar_bits, + )); + + format_ident!("_mm{prefix}_cast{src_name}_{dst_name}") +} From 97a0494406d51b0a070984963e3c3e1794a349aa Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 18:29:48 -0500 Subject: [PATCH 03/11] Consolidate the x86 codegen --- fearless_simd_gen/src/arch/avx2.rs | 18 -- fearless_simd_gen/src/arch/mod.rs | 5 +- fearless_simd_gen/src/arch/sse4_2.rs | 18 -- fearless_simd_gen/src/arch/x86.rs | 275 +++++++++++++++++++++++ fearless_simd_gen/src/arch/x86_common.rs | 129 ----------- fearless_simd_gen/src/main.rs | 1 - fearless_simd_gen/src/mk_avx2.rs | 8 +- fearless_simd_gen/src/mk_sse4_2.rs | 19 +- fearless_simd_gen/src/x86_common.rs | 144 ------------ 9 files changed, 288 insertions(+), 329 deletions(-) delete mode 100644 fearless_simd_gen/src/arch/avx2.rs delete mode 100644 fearless_simd_gen/src/arch/sse4_2.rs create mode 100644 fearless_simd_gen/src/arch/x86.rs delete mode 100644 fearless_simd_gen/src/arch/x86_common.rs delete mode 100644 fearless_simd_gen/src/x86_common.rs diff --git a/fearless_simd_gen/src/arch/avx2.rs 
b/fearless_simd_gen/src/arch/avx2.rs deleted file mode 100644 index d82585ea..00000000 --- a/fearless_simd_gen/src/arch/avx2.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2025 the Fearless_SIMD Authors -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use crate::arch::{Arch, x86_common}; -use crate::types::VecType; -use proc_macro2::TokenStream; - -pub(crate) struct Avx2; - -impl Arch for Avx2 { - fn arch_ty(&self, ty: &VecType) -> TokenStream { - x86_common::arch_ty(ty) - } - - fn expr(&self, op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream { - x86_common::expr(op, ty, args) - } -} diff --git a/fearless_simd_gen/src/arch/mod.rs b/fearless_simd_gen/src/arch/mod.rs index 97c928d4..a1d2764b 100644 --- a/fearless_simd_gen/src/arch/mod.rs +++ b/fearless_simd_gen/src/arch/mod.rs @@ -3,11 +3,8 @@ pub(crate) mod fallback; pub(crate) mod neon; - -pub(crate) mod avx2; -pub(crate) mod sse4_2; pub(crate) mod wasm; -pub(crate) mod x86_common; +pub(crate) mod x86; use proc_macro2::TokenStream; diff --git a/fearless_simd_gen/src/arch/sse4_2.rs b/fearless_simd_gen/src/arch/sse4_2.rs deleted file mode 100644 index a62803de..00000000 --- a/fearless_simd_gen/src/arch/sse4_2.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2025 the Fearless_SIMD Authors -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use crate::arch::{Arch, x86_common}; -use crate::types::VecType; -use proc_macro2::TokenStream; - -pub(crate) struct Sse4_2; - -impl Arch for Sse4_2 { - fn arch_ty(&self, ty: &VecType) -> TokenStream { - x86_common::arch_ty(ty) - } - - fn expr(&self, op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream { - x86_common::expr(op, ty, args) - } -} diff --git a/fearless_simd_gen/src/arch/x86.rs b/fearless_simd_gen/src/arch/x86.rs new file mode 100644 index 00000000..dcaca2b4 --- /dev/null +++ b/fearless_simd_gen/src/arch/x86.rs @@ -0,0 +1,275 @@ +// Copyright 2025 the Fearless_SIMD Authors +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#![expect( + unreachable_pub, + reason = "TODO: https://github.com/linebender/fearless_simd/issues/40" +)] + +use crate::arch::Arch; +use crate::types::{ScalarType, VecType}; +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{format_ident, quote}; + +pub struct X86; + +pub(crate) fn translate_op(op: &str) -> Option<&'static str> { + Some(match op { + "floor" => "floor", + "sqrt" => "sqrt", + "add" => "add", + "sub" => "sub", + "div" => "div", + "and" => "and", + "simd_eq" => "cmpeq", + "simd_lt" => "cmplt", + "simd_le" => "cmple", + "simd_ge" => "cmpge", + "simd_gt" => "cmpgt", + "or" => "or", + "xor" => "xor", + "shl" => "shl", + "shr" => "shr", + "max" => "max", + "min" => "min", + "max_precise" => "max", + "min_precise" => "min", + "select" => "blendv", + _ => return None, + }) +} + +impl Arch for X86 { + fn arch_ty(&self, ty: &VecType) -> TokenStream { + let suffix = match (ty.scalar, ty.scalar_bits) { + (ScalarType::Float, 32) => "", + (ScalarType::Float, 64) => "d", + (ScalarType::Float, _) => unimplemented!(), + (ScalarType::Unsigned | ScalarType::Int | ScalarType::Mask, _) => "i", + }; + let name = format!("__m{}{}", ty.scalar_bits * ty.len, suffix); + let ident = Ident::new(&name, Span::call_site()); + quote! 
{ #ident } + } + + fn expr(&self, op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream { + if let Some(op_name) = translate_op(op) { + let sign_aware = matches!(op, "max" | "min"); + + let suffix = match op_name { + "and" | "or" | "xor" => coarse_type(*ty), + "blendv" if ty.scalar != ScalarType::Float => "epi8", + _ => op_suffix(ty.scalar, ty.scalar_bits, sign_aware), + }; + let intrinsic = intrinsic_ident(op_name, suffix, ty.n_bits()); + quote! { #intrinsic ( #( #args ),* ) } + } else { + let suffix = op_suffix(ty.scalar, ty.scalar_bits, true); + match op { + "trunc" => { + let intrinsic = intrinsic_ident("round", suffix, ty.n_bits()); + quote! { #intrinsic ( #( #args, )* _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) } + } + "neg" => match ty.scalar { + ScalarType::Float => { + let set1 = set1_intrinsic(ty.scalar, ty.scalar_bits, ty.n_bits()); + let xor = + simple_intrinsic("xor", ScalarType::Float, ty.scalar_bits, ty.n_bits()); + quote! { + #( #xor(#args, #set1(-0.0)) )* + } + } + ScalarType::Int => { + let set0 = intrinsic_ident("setzero", coarse_type(*ty), ty.n_bits()); + let sub = simple_intrinsic("sub", ty.scalar, ty.scalar_bits, ty.n_bits()); + let arg = &args[0]; + quote! { + #sub(#set0(), #arg) + } + } + _ => unreachable!(), + }, + "abs" => { + let set1 = set1_intrinsic(ty.scalar, ty.scalar_bits, ty.n_bits()); + let andnot = + simple_intrinsic("andnot", ScalarType::Float, ty.scalar_bits, ty.n_bits()); + quote! { + #( #andnot(#set1(-0.0), #args) )* + } + } + "copysign" => { + let a = &args[0]; + let b = &args[1]; + let set1 = set1_intrinsic(ty.scalar, ty.scalar_bits, ty.n_bits()); + let and = + simple_intrinsic("and", ScalarType::Float, ty.scalar_bits, ty.n_bits()); + let andnot = + simple_intrinsic("andnot", ScalarType::Float, ty.scalar_bits, ty.n_bits()); + let or = simple_intrinsic("or", ScalarType::Float, ty.scalar_bits, ty.n_bits()); + quote! { + let mask = #set1(-0.0); + #or(#and(mask, #b), #andnot(mask, #a)) + } + } + "mul" => { + let suffix = op_suffix(ty.scalar, ty.scalar_bits, false); + let intrinsic = if matches!(ty.scalar, ScalarType::Int | ScalarType::Unsigned) { + intrinsic_ident("mullo", suffix, ty.n_bits()) + } else { + intrinsic_ident("mul", suffix, ty.n_bits()) + }; + + quote! { #intrinsic ( #( #args ),* ) } + } + "shrv" if ty.scalar_bits > 16 => { + let suffix = op_suffix(ty.scalar, ty.scalar_bits, false); + let name = match ty.scalar { + ScalarType::Int => "srav", + _ => "srlv", + }; + let intrinsic = intrinsic_ident(name, suffix, ty.n_bits()); + quote! { #intrinsic ( #( #args ),* ) } + } + _ => unimplemented!("{}", op), + } + } + } +} + +pub(crate) fn op_suffix(mut ty: ScalarType, bits: usize, sign_aware: bool) -> &'static str { + use ScalarType::*; + if !sign_aware && ty == Unsigned { + ty = Int; + } + match (ty, bits) { + (Float, 32) => "ps", + (Float, 64) => "pd", + (Float, _) => unimplemented!("{bits} bit floats"), + (Int | Mask, 8) => "epi8", + (Int | Mask, 16) => "epi16", + (Int | Mask, 32) => "epi32", + (Int | Mask, 64) => "epi64", + (Unsigned, 8) => "epu8", + (Unsigned, 16) => "epu16", + (Unsigned, 32) => "epu32", + (Unsigned, 64) => "epu64", + _ => unreachable!(), + } +} + +/// Intrinsic name for the "int, float, or double" type (not as fine-grained as [`op_suffix`]). 
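(Editor's aside, not part of the patch: the `abs`/`copysign` arms above rely on `set1(-0.0)` as a per-lane sign-bit mask. A small self-contained sketch of the same trick, with an invented function name; SSE is part of the x86_64 baseline, so no runtime feature check is needed.)

#[cfg(target_arch = "x86_64")]
unsafe fn copysign_f32x4(magnitude: [f32; 4], sign: [f32; 4]) -> [f32; 4] {
    use core::arch::x86_64::*;
    let a = _mm_loadu_ps(magnitude.as_ptr());
    let b = _mm_loadu_ps(sign.as_ptr());
    // -0.0 has only the sign bit set, so it doubles as a sign mask in every lane.
    let sign_mask = _mm_set1_ps(-0.0);
    // Sign bits come from `b`; exponent and mantissa bits come from `a`.
    let r = _mm_or_ps(_mm_and_ps(sign_mask, b), _mm_andnot_ps(sign_mask, a));
    let mut out = [0.0_f32; 4];
    _mm_storeu_ps(out.as_mut_ptr(), r);
    out
}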
+pub(crate) fn coarse_type(vec_ty: VecType) -> &'static str { + use ScalarType::*; + match (vec_ty.scalar, vec_ty.n_bits()) { + (Int | Unsigned | Mask, 128) => "si128", + (Int | Unsigned | Mask, 256) => "si256", + (Int | Unsigned | Mask, 512) => "si512", + _ => op_suffix(vec_ty.scalar, vec_ty.scalar_bits, false), + } +} + +pub(crate) fn set1_intrinsic(ty: ScalarType, bits: usize, ty_bits: usize) -> Ident { + use ScalarType::*; + let suffix = match (ty, bits) { + (Int | Unsigned | Mask, 64) => "epi64x", + _ => op_suffix(ty, bits, false), + }; + + intrinsic_ident("set1", suffix, ty_bits) +} + +pub(crate) fn simple_intrinsic(name: &str, ty: ScalarType, bits: usize, ty_bits: usize) -> Ident { + let suffix = op_suffix(ty, bits, true); + + intrinsic_ident(name, suffix, ty_bits) +} + +pub(crate) fn simple_sign_unaware_intrinsic( + name: &str, + ty: ScalarType, + bits: usize, + ty_bits: usize, +) -> Ident { + let suffix = op_suffix(ty, bits, false); + + intrinsic_ident(name, suffix, ty_bits) +} + +pub(crate) fn extend_intrinsic( + ty: ScalarType, + from_bits: usize, + to_bits: usize, + ty_bits: usize, +) -> Ident { + let from_suffix = op_suffix(ty, from_bits, true); + let to_suffix = op_suffix(ty, to_bits, false); + + intrinsic_ident(&format!("cvt{from_suffix}"), to_suffix, ty_bits) +} + +pub(crate) fn cvt_intrinsic(from: VecType, to: VecType) -> Ident { + let from_suffix = op_suffix(from.scalar, from.scalar_bits, false); + let to_suffix = op_suffix(to.scalar, to.scalar_bits, false); + + intrinsic_ident(&format!("cvt{from_suffix}"), to_suffix, from.n_bits()) +} + +pub(crate) fn pack_intrinsic(from_bits: usize, signed: bool, ty_bits: usize) -> Ident { + let unsigned = match signed { + true => "", + false => "u", + }; + let suffix = op_suffix(ScalarType::Int, from_bits, false); + + intrinsic_ident(&format!("pack{unsigned}s"), suffix, ty_bits) +} + +pub(crate) fn unpack_intrinsic( + scalar_type: ScalarType, + scalar_bits: usize, + low: bool, + ty_bits: usize, +) -> Ident { + let suffix = op_suffix(scalar_type, scalar_bits, false); + + let low_pref = if low { "lo" } else { "hi" }; + + intrinsic_ident(&format!("unpack{low_pref}"), suffix, ty_bits) +} + +pub(crate) fn intrinsic_ident(name: &str, suffix: &str, ty_bits: usize) -> Ident { + let prefix = match ty_bits { + 128 => "", + 256 => "256", + 512 => "512", + _ => unreachable!(), + }; + + format_ident!("_mm{prefix}_{name}_{suffix}") +} + +pub(crate) fn cast_ident( + src_scalar_ty: ScalarType, + dst_scalar_ty: ScalarType, + scalar_bits: usize, + ty_bits: usize, +) -> Ident { + let prefix = match ty_bits { + 128 => "", + 256 => "256", + 512 => "512", + _ => unreachable!(), + }; + let src_name = coarse_type(VecType::new( + src_scalar_ty, + scalar_bits, + ty_bits / scalar_bits, + )); + let dst_name = coarse_type(VecType::new( + dst_scalar_ty, + scalar_bits, + ty_bits / scalar_bits, + )); + + format_ident!("_mm{prefix}_cast{src_name}_{dst_name}") +} diff --git a/fearless_simd_gen/src/arch/x86_common.rs b/fearless_simd_gen/src/arch/x86_common.rs deleted file mode 100644 index 615cc77d..00000000 --- a/fearless_simd_gen/src/arch/x86_common.rs +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2025 the Fearless_SIMD Authors -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use crate::types::{ScalarType, VecType}; -use crate::x86_common::{ - coarse_type, intrinsic_ident, op_suffix, set1_intrinsic, simple_intrinsic, -}; -use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; - -pub(crate) fn translate_op(op: &str) -> Option<&'static str> { - Some(match op { 
- "floor" => "floor", - "sqrt" => "sqrt", - "add" => "add", - "sub" => "sub", - "div" => "div", - "and" => "and", - "simd_eq" => "cmpeq", - "simd_lt" => "cmplt", - "simd_le" => "cmple", - "simd_ge" => "cmpge", - "simd_gt" => "cmpgt", - "or" => "or", - "xor" => "xor", - "shl" => "shl", - "shr" => "shr", - "max" => "max", - "min" => "min", - "max_precise" => "max", - "min_precise" => "min", - "select" => "blendv", - _ => return None, - }) -} - -pub(crate) fn arch_ty(ty: &VecType) -> TokenStream { - let suffix = match (ty.scalar, ty.scalar_bits) { - (ScalarType::Float, 32) => "", - (ScalarType::Float, 64) => "d", - (ScalarType::Float, _) => unimplemented!(), - (ScalarType::Unsigned | ScalarType::Int | ScalarType::Mask, _) => "i", - }; - let name = format!("__m{}{}", ty.scalar_bits * ty.len, suffix); - let ident = Ident::new(&name, Span::call_site()); - quote! { #ident } -} - -pub(crate) fn expr(op: &str, ty: &VecType, args: &[TokenStream]) -> TokenStream { - if let Some(op_name) = translate_op(op) { - let sign_aware = matches!(op, "max" | "min"); - - let suffix = match op_name { - "and" | "or" | "xor" => coarse_type(*ty), - "blendv" if ty.scalar != ScalarType::Float => "epi8", - _ => op_suffix(ty.scalar, ty.scalar_bits, sign_aware), - }; - let intrinsic = intrinsic_ident(op_name, suffix, ty.n_bits()); - quote! { #intrinsic ( #( #args ),* ) } - } else { - let suffix = op_suffix(ty.scalar, ty.scalar_bits, true); - match op { - "trunc" => { - let intrinsic = intrinsic_ident("round", suffix, ty.n_bits()); - quote! { #intrinsic ( #( #args, )* _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) } - } - "neg" => match ty.scalar { - ScalarType::Float => { - let set1 = set1_intrinsic(ty.scalar, ty.scalar_bits, ty.n_bits()); - let xor = - simple_intrinsic("xor", ScalarType::Float, ty.scalar_bits, ty.n_bits()); - quote! { - #( #xor(#args, #set1(-0.0)) )* - } - } - ScalarType::Int => { - let set0 = intrinsic_ident("setzero", coarse_type(*ty), ty.n_bits()); - let sub = simple_intrinsic("sub", ty.scalar, ty.scalar_bits, ty.n_bits()); - let arg = &args[0]; - quote! { - #sub(#set0(), #arg) - } - } - _ => unreachable!(), - }, - "abs" => { - let set1 = set1_intrinsic(ty.scalar, ty.scalar_bits, ty.n_bits()); - let andnot = - simple_intrinsic("andnot", ScalarType::Float, ty.scalar_bits, ty.n_bits()); - quote! { - #( #andnot(#set1(-0.0), #args) )* - } - } - "copysign" => { - let a = &args[0]; - let b = &args[1]; - let set1 = set1_intrinsic(ty.scalar, ty.scalar_bits, ty.n_bits()); - let and = simple_intrinsic("and", ScalarType::Float, ty.scalar_bits, ty.n_bits()); - let andnot = - simple_intrinsic("andnot", ScalarType::Float, ty.scalar_bits, ty.n_bits()); - let or = simple_intrinsic("or", ScalarType::Float, ty.scalar_bits, ty.n_bits()); - quote! { - let mask = #set1(-0.0); - #or(#and(mask, #b), #andnot(mask, #a)) - } - } - "mul" => { - let suffix = op_suffix(ty.scalar, ty.scalar_bits, false); - let intrinsic = if matches!(ty.scalar, ScalarType::Int | ScalarType::Unsigned) { - intrinsic_ident("mullo", suffix, ty.n_bits()) - } else { - intrinsic_ident("mul", suffix, ty.n_bits()) - }; - - quote! { #intrinsic ( #( #args ),* ) } - } - "shrv" if ty.scalar_bits > 16 => { - let suffix = op_suffix(ty.scalar, ty.scalar_bits, false); - let name = match ty.scalar { - ScalarType::Int => "srav", - _ => "srlv", - }; - let intrinsic = intrinsic_ident(name, suffix, ty.n_bits()); - quote! 
{ #intrinsic ( #( #args ),* ) } - } - _ => unimplemented!("{}", op), - } - } -} diff --git a/fearless_simd_gen/src/main.rs b/fearless_simd_gen/src/main.rs index 3aab3b38..fad1187c 100644 --- a/fearless_simd_gen/src/main.rs +++ b/fearless_simd_gen/src/main.rs @@ -24,7 +24,6 @@ mod mk_sse4_2; mod mk_wasm; mod ops; mod types; -mod x86_common; #[derive(Clone, Copy, ValueEnum, Debug)] enum Module { diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index a205b382..4b905d47 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -2,13 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::arch::Arch; -use crate::arch::avx2::Avx2; -use crate::arch::sse4_2::Sse4_2; +use crate::arch::x86::{X86, cast_ident, simple_intrinsic}; use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; use crate::mk_sse4_2; use crate::ops::{OpSig, TyFlavor, ops_for_type}; use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; -use crate::x86_common::{cast_ident, simple_intrinsic}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -92,7 +90,7 @@ fn mk_simd_impl() -> TokenStream { continue; } - let method = make_method(method, sig, vec_ty, Sse4_2, vec_ty.n_bits()); + let method = make_method(method, sig, vec_ty, X86, vec_ty.n_bits()); methods.push(method); } @@ -140,7 +138,7 @@ fn mk_type_impl() -> TokenStream { continue; } let simd = ty.rust(); - let arch = Avx2.arch_ty(ty); + let arch = X86.arch_ty(ty); result.push(quote! { impl SimdFrom<#arch, S> for #simd { #[inline(always)] diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index ac2aed6f..db158a31 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -2,15 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::arch::Arch; -use crate::arch::sse4_2::Sse4_2; -use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; -use crate::ops::{OpSig, TyFlavor, ops_for_type, reinterpret_ty, valid_reinterpret}; -use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; -use crate::x86_common::{ - cast_ident, coarse_type, cvt_intrinsic, extend_intrinsic, intrinsic_ident, op_suffix, +use crate::arch::x86::{ + X86, cast_ident, coarse_type, cvt_intrinsic, extend_intrinsic, intrinsic_ident, op_suffix, pack_intrinsic, set1_intrinsic, simple_intrinsic, simple_sign_unaware_intrinsic, unpack_intrinsic, }; +use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; +use crate::ops::{OpSig, TyFlavor, ops_for_type, reinterpret_ty, valid_reinterpret}; +use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote}; @@ -95,7 +94,7 @@ fn mk_simd_impl() -> TokenStream { continue; } - let method = make_method(method, sig, vec_ty, Sse4_2, 128); + let method = make_method(method, sig, vec_ty, X86, 128); methods.push(method); } @@ -147,7 +146,7 @@ fn mk_type_impl() -> TokenStream { continue; } let simd = ty.rust(); - let arch = Sse4_2.arch_ty(ty); + let arch = X86.arch_ty(ty); result.push(quote! { impl SimdFrom<#arch, S> for #simd { #[inline(always)] @@ -540,7 +539,7 @@ pub(crate) fn handle_ternary( quote! { c.into() }, ]; - let expr = Sse4_2.expr(method, vec_ty, &args); + let expr = X86.expr(method, vec_ty, &args); quote! { #method_sig { #expr.simd_into(self) @@ -572,7 +571,7 @@ pub(crate) fn handle_select( _ => quote! 
{ a.into() }, }, ]; - let expr = Sse4_2.expr("select", vec_ty, &args); + let expr = X86.expr("select", vec_ty, &args); quote! { #method_sig { diff --git a/fearless_simd_gen/src/x86_common.rs b/fearless_simd_gen/src/x86_common.rs deleted file mode 100644 index ab38097c..00000000 --- a/fearless_simd_gen/src/x86_common.rs +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2025 the Fearless_SIMD Authors -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use crate::types::{ScalarType, VecType}; -use proc_macro2::Ident; -use quote::format_ident; - -pub(crate) fn op_suffix(mut ty: ScalarType, bits: usize, sign_aware: bool) -> &'static str { - use ScalarType::*; - if !sign_aware && ty == Unsigned { - ty = Int; - } - match (ty, bits) { - (Float, 32) => "ps", - (Float, 64) => "pd", - (Float, _) => unimplemented!("{bits} bit floats"), - (Int | Mask, 8) => "epi8", - (Int | Mask, 16) => "epi16", - (Int | Mask, 32) => "epi32", - (Int | Mask, 64) => "epi64", - (Unsigned, 8) => "epu8", - (Unsigned, 16) => "epu16", - (Unsigned, 32) => "epu32", - (Unsigned, 64) => "epu64", - _ => unreachable!(), - } -} - -/// Intrinsic name for the "int, float, or double" type (not as fine-grained as [`op_suffix`]). -pub(crate) fn coarse_type(vec_ty: VecType) -> &'static str { - use ScalarType::*; - match (vec_ty.scalar, vec_ty.n_bits()) { - (Int | Unsigned | Mask, 128) => "si128", - (Int | Unsigned | Mask, 256) => "si256", - (Int | Unsigned | Mask, 512) => "si512", - _ => op_suffix(vec_ty.scalar, vec_ty.scalar_bits, false), - } -} - -pub(crate) fn set1_intrinsic(ty: ScalarType, bits: usize, ty_bits: usize) -> Ident { - use ScalarType::*; - let suffix = match (ty, bits) { - (Int | Unsigned | Mask, 64) => "epi64x", - _ => op_suffix(ty, bits, false), - }; - - intrinsic_ident("set1", suffix, ty_bits) -} - -pub(crate) fn simple_intrinsic(name: &str, ty: ScalarType, bits: usize, ty_bits: usize) -> Ident { - let suffix = op_suffix(ty, bits, true); - - intrinsic_ident(name, suffix, ty_bits) -} - -pub(crate) fn simple_sign_unaware_intrinsic( - name: &str, - ty: ScalarType, - bits: usize, - ty_bits: usize, -) -> Ident { - let suffix = op_suffix(ty, bits, false); - - intrinsic_ident(name, suffix, ty_bits) -} - -pub(crate) fn extend_intrinsic( - ty: ScalarType, - from_bits: usize, - to_bits: usize, - ty_bits: usize, -) -> Ident { - let from_suffix = op_suffix(ty, from_bits, true); - let to_suffix = op_suffix(ty, to_bits, false); - - intrinsic_ident(&format!("cvt{from_suffix}"), to_suffix, ty_bits) -} - -pub(crate) fn cvt_intrinsic(from: VecType, to: VecType) -> Ident { - let from_suffix = op_suffix(from.scalar, from.scalar_bits, false); - let to_suffix = op_suffix(to.scalar, to.scalar_bits, false); - - intrinsic_ident(&format!("cvt{from_suffix}"), to_suffix, from.n_bits()) -} - -pub(crate) fn pack_intrinsic(from_bits: usize, signed: bool, ty_bits: usize) -> Ident { - let unsigned = match signed { - true => "", - false => "u", - }; - let suffix = op_suffix(ScalarType::Int, from_bits, false); - - intrinsic_ident(&format!("pack{unsigned}s"), suffix, ty_bits) -} - -pub(crate) fn unpack_intrinsic( - scalar_type: ScalarType, - scalar_bits: usize, - low: bool, - ty_bits: usize, -) -> Ident { - let suffix = op_suffix(scalar_type, scalar_bits, false); - - let low_pref = if low { "lo" } else { "hi" }; - - intrinsic_ident(&format!("unpack{low_pref}"), suffix, ty_bits) -} - -pub(crate) fn intrinsic_ident(name: &str, suffix: &str, ty_bits: usize) -> Ident { - let prefix = match ty_bits { - 128 => "", - 256 => "256", - 512 => "512", - _ => 
unreachable!(), - }; - - format_ident!("_mm{prefix}_{name}_{suffix}") -} - -pub(crate) fn cast_ident( - src_scalar_ty: ScalarType, - dst_scalar_ty: ScalarType, - scalar_bits: usize, - ty_bits: usize, -) -> Ident { - let prefix = match ty_bits { - 128 => "", - 256 => "256", - 512 => "512", - _ => unreachable!(), - }; - let src_name = coarse_type(VecType::new( - src_scalar_ty, - scalar_bits, - ty_bits / scalar_bits, - )); - let dst_name = coarse_type(VecType::new( - dst_scalar_ty, - scalar_bits, - ty_bits / scalar_bits, - )); - - format_ident!("_mm{prefix}_cast{src_name}_{dst_name}") -} From c4f15a134cc9062ba677d9b23447847e0186817f Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 19:43:53 -0500 Subject: [PATCH 04/11] Add more split/combine tests --- fearless_simd_tests/tests/harness/mod.rs | 231 +++++++++++++++++++++++ 1 file changed, 231 insertions(+) diff --git a/fearless_simd_tests/tests/harness/mod.rs b/fearless_simd_tests/tests/harness/mod.rs index 3e0df758..9280e113 100644 --- a/fearless_simd_tests/tests/harness/mod.rs +++ b/fearless_simd_tests/tests/harness/mod.rs @@ -1952,6 +1952,237 @@ fn combine_u32x4(simd: S) { assert_eq!(a.combine(b).val, [1, 2, 3, 4, 5, 6, 7, 8]); } +#[simd_test] +fn combine_f32x8(simd: S) { + let a = f32x8::from_slice(simd, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + let b = f32x8::from_slice(simd, &[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]); + assert_eq!( + a.combine(b).val, + [ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0 + ] + ); +} + +#[simd_test] +fn combine_i8x32(simd: S) { + let a = i8x32::from_slice( + simd, + &[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + let b = i8x32::from_slice( + simd, + &[ + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + ], + ); + assert_eq!( + a.combine(b).val, + [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, + 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + ); +} + +#[simd_test] +fn combine_u8x32(simd: S) { + let a = u8x32::from_slice( + simd, + &[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + let b = u8x32::from_slice( + simd, + &[ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + ], + ); + assert_eq!( + a.combine(b).val, + [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + ); +} + +#[simd_test] +fn combine_i16x16(simd: S) { + let a = i16x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let b = i16x16::from_slice( + simd, + &[ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + assert_eq!( + a.combine(b).val, + [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32 + ] + ); +} + +#[simd_test] +fn combine_u16x16(simd: S) { + let a = u16x16::from_slice( + simd, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + ); 
+ let b = u16x16::from_slice( + simd, + &[ + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + assert_eq!( + a.combine(b).val, + [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31 + ] + ); +} + +#[simd_test] +fn combine_i32x8(simd: S) { + let a = i32x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let b = i32x8::from_slice(simd, &[9, 10, 11, 12, 13, 14, 15, 16]); + assert_eq!( + a.combine(b).val, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + ); +} + +#[simd_test] +fn combine_u32x8(simd: S) { + let a = u32x8::from_slice(simd, &[0, 1, 2, 3, 4, 5, 6, 7]); + let b = u32x8::from_slice(simd, &[8, 9, 10, 11, 12, 13, 14, 15]); + assert_eq!( + a.combine(b).val, + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + ); +} + +#[simd_test] +fn combine_f64x4(simd: S) { + let a = f64x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + let b = f64x4::from_slice(simd, &[5.0, 6.0, 7.0, 8.0]); + assert_eq!(a.combine(b).val, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); +} + +#[simd_test] +fn split_f32x8(simd: S) { + let a = f32x8::from_slice(simd, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + let (lo, hi) = simd.split_f32x8(a); + assert_eq!(lo.val, [1.0, 2.0, 3.0, 4.0]); + assert_eq!(hi.val, [5.0, 6.0, 7.0, 8.0]); +} + +#[simd_test] +fn split_i8x32(simd: S) { + let a = i8x32::from_slice( + simd, + &[ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, + ], + ); + let (lo, hi) = simd.split_i8x32(a); + assert_eq!( + lo.val, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + ); + assert_eq!( + hi.val, + [ + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 + ] + ); +} + +#[simd_test] +fn split_u8x32(simd: S) { + let a = u8x32::from_slice( + simd, + &[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + let (lo, hi) = simd.split_u8x32(a); + assert_eq!( + lo.val, + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + ); + assert_eq!( + hi.val, + [ + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + ] + ); +} + +#[simd_test] +fn split_i16x16(simd: S) { + let a = i16x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ); + let (lo, hi) = simd.split_i16x16(a); + assert_eq!(lo.val, [1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(hi.val, [9, 10, 11, 12, 13, 14, 15, 16]); +} + +#[simd_test] +fn split_u16x16(simd: S) { + let a = u16x16::from_slice( + simd, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + ); + let (lo, hi) = simd.split_u16x16(a); + assert_eq!(lo.val, [0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(hi.val, [8, 9, 10, 11, 12, 13, 14, 15]); +} + +#[simd_test] +fn split_i32x8(simd: S) { + let a = i32x8::from_slice(simd, &[1, 2, 3, 4, 5, 6, 7, 8]); + let (lo, hi) = simd.split_i32x8(a); + assert_eq!(lo.val, [1, 2, 3, 4]); + assert_eq!(hi.val, [5, 6, 7, 8]); +} + +#[simd_test] +fn split_u32x8(simd: S) { + let a = u32x8::from_slice(simd, &[0, 1, 2, 3, 4, 5, 6, 7]); + let (lo, hi) = simd.split_u32x8(a); + assert_eq!(lo.val, [0, 1, 2, 3]); + assert_eq!(hi.val, [4, 5, 6, 7]); +} + +#[simd_test] +fn split_f64x4(simd: S) { + let a = f64x4::from_slice(simd, &[1.0, 2.0, 3.0, 4.0]); + let (lo, hi) = simd.split_f64x4(a); + assert_eq!(lo.val, [1.0, 2.0]); + assert_eq!(hi.val, [3.0, 4.0]); +} + #[simd_test] fn select_f32x4(simd: S) { let mask = mask32x4::from_slice(simd, &[-1, 0, -1, 0]); From 
0ca9b62a15a271b5c2b68df77afdddc267a674a9 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 19:37:30 -0500 Subject: [PATCH 05/11] Add specialized AVX2 split/combine ops --- fearless_simd/src/generated/avx2.rs | 192 ++++++++++++---------------- fearless_simd_gen/src/mk_avx2.rs | 48 ++++++- 2 files changed, 129 insertions(+), 111 deletions(-) diff --git a/fearless_simd/src/generated/avx2.rs b/fearless_simd/src/generated/avx2.rs index 6ed464cd..44612fa2 100644 --- a/fearless_simd/src/generated/avx2.rs +++ b/fearless_simd/src/generated/avx2.rs @@ -180,10 +180,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_f32x4(self, a: f32x4, b: f32x4) -> f32x8 { - let mut result = [0.0; 8usize]; - result[0..4usize].copy_from_slice(&a.val); - result[4usize..8usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn reinterpret_f64_f32x4(self, a: f32x4) -> f64x2 { @@ -343,10 +340,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_i8x16(self, a: i8x16, b: i8x16) -> i8x32 { - let mut result = [0; 32usize]; - result[0..16usize].copy_from_slice(&a.val); - result[16usize..32usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn neg_i8x16(self, a: i8x16) -> i8x16 { @@ -496,10 +490,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_u8x16(self, a: u8x16, b: u8x16) -> u8x32 { - let mut result = [0; 32usize]; - result[0..16usize].copy_from_slice(&a.val); - result[16usize..32usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn widen_u8x16(self, a: u8x16) -> u16x16 { @@ -552,10 +543,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_mask8x16(self, a: mask8x16, b: mask8x16) -> mask8x32 { - let mut result = [0; 32usize]; - result[0..16usize].copy_from_slice(&a.val); - result[16usize..32usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn splat_i16x8(self, val: i16) -> i16x8 { @@ -661,10 +649,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_i16x8(self, a: i16x8, b: i16x8) -> i16x16 { - let mut result = [0; 16usize]; - result[0..8usize].copy_from_slice(&a.val); - result[8usize..16usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn neg_i16x8(self, a: i16x8) -> i16x8 { @@ -798,10 +783,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_u16x8(self, a: u16x8, b: u16x8) -> u16x16 { - let mut result = [0; 16usize]; - result[0..8usize].copy_from_slice(&a.val); - result[8usize..16usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_u16x8(self, a: u16x8) -> u8x16 { @@ -852,10 +834,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_mask16x8(self, a: mask16x8, b: mask16x8) -> mask16x16 { - let mut result = [0; 16usize]; - result[0..8usize].copy_from_slice(&a.val); - result[8usize..16usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn splat_i32x4(self, val: i32) -> i32x4 { @@ -959,10 +938,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_i32x4(self, a: i32x4, b: i32x4) -> i32x8 { - let mut result = [0; 8usize]; - 
result[0..4usize].copy_from_slice(&a.val); - result[4usize..8usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn neg_i32x4(self, a: i32x4) -> i32x4 { @@ -1098,10 +1074,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_u32x4(self, a: u32x4, b: u32x4) -> u32x8 { - let mut result = [0; 8usize]; - result[0..4usize].copy_from_slice(&a.val); - result[4usize..8usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_u32x4(self, a: u32x4) -> u8x16 { @@ -1149,10 +1122,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_mask32x4(self, a: mask32x4, b: mask32x4) -> mask32x8 { - let mut result = [0; 8usize]; - result[0..4usize].copy_from_slice(&a.val); - result[4usize..8usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn splat_f64x2(self, val: f64) -> f64x2 { @@ -1271,10 +1241,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_f64x2(self, a: f64x2, b: f64x2) -> f64x4 { - let mut result = [0.0; 4usize]; - result[0..2usize].copy_from_slice(&a.val); - result[2usize..4usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128d(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn reinterpret_f32_f64x2(self, a: f64x2) -> f32x4 { @@ -1318,10 +1285,7 @@ impl Simd for Avx2 { } #[inline(always)] fn combine_mask64x2(self, a: mask64x2, b: mask64x2) -> mask64x4 { - let mut result = [0; 4usize]; - result[0..2usize].copy_from_slice(&a.val); - result[2usize..4usize].copy_from_slice(&b.val); - result.simd_into(self) + unsafe { _mm256_setr_m128i(a.into(), b.into()).simd_into(self) } } #[inline(always)] fn splat_f32x8(self, val: f32) -> f32x8 { @@ -1469,11 +1433,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_f32x8(self, a: f32x8) -> (f32x4, f32x4) { - let mut b0 = [0.0; 4usize]; - let mut b1 = [0.0; 4usize]; - b0.copy_from_slice(&a.val[0..4usize]); - b1.copy_from_slice(&a.val[4usize..8usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extractf128_ps::<0>(a.into()).simd_into(self), + _mm256_extractf128_ps::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn reinterpret_f64_f32x8(self, a: f32x8) -> f64x4 { @@ -1666,11 +1631,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_i8x32(self, a: i8x32) -> (i8x16, i8x16) { - let mut b0 = [0; 16usize]; - let mut b1 = [0; 16usize]; - b0.copy_from_slice(&a.val[0..16usize]); - b1.copy_from_slice(&a.val[16usize..32usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn neg_i8x32(self, a: i8x32) -> i8x32 { @@ -1849,11 +1815,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_u8x32(self, a: u8x32) -> (u8x16, u8x16) { - let mut b0 = [0; 16usize]; - let mut b1 = [0; 16usize]; - b0.copy_from_slice(&a.val[0..16usize]); - b1.copy_from_slice(&a.val[16usize..32usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn widen_u8x32(self, a: u8x32) -> u16x32 { @@ -1909,11 +1876,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_mask8x32(self, a: mask8x32) -> (mask8x16, mask8x16) { - let 
mut b0 = [0; 16usize]; - let mut b1 = [0; 16usize]; - b0.copy_from_slice(&a.val[0..16usize]); - b1.copy_from_slice(&a.val[16usize..32usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn splat_i16x16(self, val: i16) -> i16x16 { @@ -2052,11 +2020,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_i16x16(self, a: i16x16) -> (i16x8, i16x8) { - let mut b0 = [0; 8usize]; - let mut b1 = [0; 8usize]; - b0.copy_from_slice(&a.val[0..8usize]); - b1.copy_from_slice(&a.val[8usize..16usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn neg_i16x16(self, a: i16x16) -> i16x16 { @@ -2223,11 +2192,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_u16x16(self, a: u16x16) -> (u16x8, u16x8) { - let mut b0 = [0; 8usize]; - let mut b1 = [0; 8usize]; - b0.copy_from_slice(&a.val[0..8usize]); - b1.copy_from_slice(&a.val[8usize..16usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn narrow_u16x16(self, a: u16x16) -> u8x16 { @@ -2296,11 +2266,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_mask16x16(self, a: mask16x16) -> (mask16x8, mask16x8) { - let mut b0 = [0; 8usize]; - let mut b1 = [0; 8usize]; - b0.copy_from_slice(&a.val[0..8usize]); - b1.copy_from_slice(&a.val[8usize..16usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn splat_i32x8(self, val: i32) -> i32x8 { @@ -2427,11 +2398,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_i32x8(self, a: i32x8) -> (i32x4, i32x4) { - let mut b0 = [0; 4usize]; - let mut b1 = [0; 4usize]; - b0.copy_from_slice(&a.val[0..4usize]); - b1.copy_from_slice(&a.val[4usize..8usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn neg_i32x8(self, a: i32x8) -> i32x8 { @@ -2590,11 +2562,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_u32x8(self, a: u32x8) -> (u32x4, u32x4) { - let mut b0 = [0; 4usize]; - let mut b1 = [0; 4usize]; - b0.copy_from_slice(&a.val[0..4usize]); - b1.copy_from_slice(&a.val[4usize..8usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn reinterpret_u8_u32x8(self, a: u32x8) -> u8x32 { @@ -2649,11 +2622,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_mask32x8(self, a: mask32x8) -> (mask32x4, mask32x4) { - let mut b0 = [0; 4usize]; - let mut b1 = [0; 4usize]; - b0.copy_from_slice(&a.val[0..4usize]); - b1.copy_from_slice(&a.val[4usize..8usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn splat_f64x4(self, val: f64) -> f64x4 { @@ -2801,11 +2775,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_f64x4(self, a: 
f64x4) -> (f64x2, f64x2) { - let mut b0 = [0.0; 2usize]; - let mut b1 = [0.0; 2usize]; - b0.copy_from_slice(&a.val[0..2usize]); - b1.copy_from_slice(&a.val[2usize..4usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extractf128_pd::<0>(a.into()).simd_into(self), + _mm256_extractf128_pd::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn reinterpret_f32_f64x4(self, a: f64x4) -> f32x8 { @@ -2856,11 +2831,12 @@ impl Simd for Avx2 { } #[inline(always)] fn split_mask64x4(self, a: mask64x4) -> (mask64x2, mask64x2) { - let mut b0 = [0; 2usize]; - let mut b1 = [0; 2usize]; - b0.copy_from_slice(&a.val[0..2usize]); - b1.copy_from_slice(&a.val[2usize..4usize]); - (b0.simd_into(self), b1.simd_into(self)) + unsafe { + ( + _mm256_extracti128_si256::<0>(a.into()).simd_into(self), + _mm256_extracti128_si256::<1>(a.into()).simd_into(self), + ) + } } #[inline(always)] fn splat_f32x16(self, a: f32) -> f32x16 { diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index 4b905d47..5304f9e8 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::arch::Arch; -use crate::arch::x86::{X86, cast_ident, simple_intrinsic}; +use crate::arch::x86::{X86, cast_ident, coarse_type, intrinsic_ident, simple_intrinsic}; use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; use crate::mk_sse4_2; use crate::ops::{OpSig, TyFlavor, ops_for_type}; @@ -219,8 +219,8 @@ fn make_method( _ => mk_sse4_2::handle_ternary(method_sig, &method_ident, method, vec_ty), }, OpSig::Select => mk_sse4_2::handle_select(method_sig, vec_ty, scalar_bits), - OpSig::Combine => generic_combine(vec_ty), - OpSig::Split => generic_split(vec_ty), + OpSig::Combine => handle_combine(method_sig, vec_ty), + OpSig::Split => handle_split(method_sig, vec_ty), OpSig::Zip(zip1) => mk_sse4_2::handle_zip(method_sig, vec_ty, scalar_bits, zip1), OpSig::Unzip(select_even) => { mk_sse4_2::handle_unzip(method_sig, vec_ty, scalar_bits, select_even) @@ -240,6 +240,48 @@ fn make_method( } } +pub(crate) fn handle_split(method_sig: TokenStream, vec_ty: &VecType) -> TokenStream { + if vec_ty.n_bits() == 256 { + let extract_op = match vec_ty.scalar { + ScalarType::Float => "extractf128", + _ => "extracti128", + }; + let extract_intrinsic = intrinsic_ident(extract_op, coarse_type(*vec_ty), 256); + quote! { + #method_sig { + unsafe { + ( + #extract_intrinsic::<0>(a.into()).simd_into(self), + #extract_intrinsic::<1>(a.into()).simd_into(self), + ) + } + } + } + } else { + generic_split(vec_ty) + } +} + +pub(crate) fn handle_combine(method_sig: TokenStream, vec_ty: &VecType) -> TokenStream { + if vec_ty.n_bits() == 128 { + let suffix = match (vec_ty.scalar, vec_ty.scalar_bits) { + (ScalarType::Float, 32) => "m128", + (ScalarType::Float, 64) => "m128d", + _ => "m128i", + }; + let set_intrinsic = intrinsic_ident("setr", suffix, 256); + quote! 
{ + #method_sig { + unsafe { + #set_intrinsic(a.into(), b.into()).simd_into(self) + } + } + } + } else { + generic_combine(vec_ty) + } +} + pub(crate) fn handle_compare( method_sig: TokenStream, method: &str, From 1fee324fa8616d43d1b61b6a1fcdd021e54bb723 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 18:50:57 -0500 Subject: [PATCH 06/11] Add 256-bit widen/narrow tests --- fearless_simd_tests/tests/harness/mod.rs | 36 ++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/fearless_simd_tests/tests/harness/mod.rs b/fearless_simd_tests/tests/harness/mod.rs index 9280e113..35cd6e07 100644 --- a/fearless_simd_tests/tests/harness/mod.rs +++ b/fearless_simd_tests/tests/harness/mod.rs @@ -2344,6 +2344,42 @@ fn narrow_u16x16(simd: S) { ); } +#[simd_test] +fn widen_u8x32(simd: S) { + let a = u8x32::from_slice( + simd, + &[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, + ], + ); + assert_eq!( + simd.widen_u8x32(a).val, + [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31 + ] + ); +} + +#[simd_test] +fn narrow_u16x32(simd: S) { + let a = u16x32::from_slice( + simd, + &[ + 0, 1, 127, 128, 255, 256, 300, 1000, 128, 192, 224, 240, 248, 252, 254, 255, 100, 200, + 255, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65535, 0, 1, 2, 3, + ], + ); + assert_eq!( + simd.narrow_u16x32(a).val, + [ + 0, 1, 127, 128, 255, 0, 44, 232, 128, 192, 224, 240, 248, 252, 254, 255, 100, 200, 255, + 0, 0, 0, 0, 0, 0, 0, 0, 255, 0, 1, 2, 3 + ] + ); +} + #[simd_test] fn abs_f64x2(simd: S) { let a = f64x2::from_slice(simd, &[-1.5, 2.5]); From bd31e478590985cb09f0b380dae008e4c5aef9d0 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 20:28:42 -0500 Subject: [PATCH 07/11] Add specialized AVX2 widen/narrow ops --- fearless_simd/src/generated/avx2.rs | 40 ++++++---- fearless_simd_gen/src/mk_avx2.rs | 119 ++++++++++++++++++++++++++-- fearless_simd_gen/src/mk_sse4_2.rs | 6 +- 3 files changed, 138 insertions(+), 27 deletions(-) diff --git a/fearless_simd/src/generated/avx2.rs b/fearless_simd/src/generated/avx2.rs index 44612fa2..5fed3847 100644 --- a/fearless_simd/src/generated/avx2.rs +++ b/fearless_simd/src/generated/avx2.rs @@ -494,12 +494,7 @@ impl Simd for Avx2 { } #[inline(always)] fn widen_u8x16(self, a: u8x16) -> u16x16 { - unsafe { - let raw = a.into(); - let high = _mm_cvtepu8_epi16(raw).simd_into(self); - let low = _mm_cvtepu8_epi16(_mm_srli_si128::<8>(raw)).simd_into(self); - self.combine_u16x8(high, low) - } + unsafe { _mm256_cvtepu8_epi16(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u32_u8x16(self, a: u8x16) -> u32x4 { @@ -1824,8 +1819,12 @@ impl Simd for Avx2 { } #[inline(always)] fn widen_u8x32(self, a: u8x32) -> u16x32 { - let (a0, a1) = self.split_u8x32(a); - self.combine_u16x16(self.widen_u8x16(a0), self.widen_u8x16(a1)) + unsafe { + let (a0, a1) = self.split_u8x32(a); + let high = _mm256_cvtepu8_epi16(a0.into()).simd_into(self); + let low = _mm256_cvtepu8_epi16(a1.into()).simd_into(self); + self.combine_u16x16(high, low) + } } #[inline(always)] fn reinterpret_u32_u8x32(self, a: u8x32) -> u32x8 { @@ -2201,13 +2200,14 @@ impl Simd for Avx2 { } #[inline(always)] fn narrow_u16x16(self, a: u16x16) -> u8x16 { - let (a, b) = self.split_u16x16(a); unsafe { - let mask = _mm_set1_epi16(0xFF); - let lo_masked = _mm_and_si128(a.into(), mask); - let hi_masked = _mm_and_si128(b.into(), mask); - let result = 
_mm_packus_epi16(lo_masked, hi_masked); - result.simd_into(self) + let mask = _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, 0, 2, 4, 6, 8, 10, 12, + 14, -1, -1, -1, -1, -1, -1, -1, -1, + ); + let shuffled = _mm256_shuffle_epi8(a.into(), mask); + let packed = _mm256_permute4x64_epi64::<0b11_01_10_00>(shuffled); + _mm256_castsi256_si128(packed).simd_into(self) } } #[inline(always)] @@ -3781,8 +3781,16 @@ impl Simd for Avx2 { } #[inline(always)] fn narrow_u16x32(self, a: u16x32) -> u8x32 { - let (a0, a1) = self.split_u16x32(a); - self.combine_u8x16(self.narrow_u16x16(a0), self.narrow_u16x16(a1)) + let (a, b) = self.split_u16x32(a); + unsafe { + let mask = _mm256_set1_epi16(0xFF); + let lo_masked = _mm256_and_si256(a.into(), mask); + let hi_masked = _mm256_and_si256(b.into(), mask); + let result = _mm256_permute4x64_epi64::<0b_11_01_10_00>(_mm256_packus_epi16( + lo_masked, hi_masked, + )); + result.simd_into(self) + } } #[inline(always)] fn reinterpret_u8_u16x32(self, a: u16x32) -> u8x64 { diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index 5304f9e8..8ff93729 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -2,13 +2,16 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::arch::Arch; -use crate::arch::x86::{X86, cast_ident, coarse_type, intrinsic_ident, simple_intrinsic}; +use crate::arch::x86::{ + X86, cast_ident, coarse_type, extend_intrinsic, intrinsic_ident, pack_intrinsic, + set1_intrinsic, simple_intrinsic, +}; use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; use crate::mk_sse4_2; use crate::ops::{OpSig, TyFlavor, ops_for_type}; use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote}; #[derive(Clone, Copy)] pub(crate) struct Level; @@ -80,7 +83,8 @@ fn mk_simd_impl() -> TokenStream { let mut methods = vec![]; for vec_ty in SIMD_TYPES { for (method, sig) in ops_for_type(vec_ty, true) { - let too_wide = vec_ty.n_bits() > 256; + let too_wide = (vec_ty.n_bits() > 256 && !matches!(method, "split" | "narrow")) + || vec_ty.n_bits() > 512; let acceptable_wide_op = matches!(method, "load_interleaved_128") || matches!(method, "store_interleaved_128"); @@ -189,11 +193,7 @@ fn make_method( OpSig::Compare => handle_compare(method_sig, method, vec_ty, scalar_bits, ty_bits, arch), OpSig::Unary => mk_sse4_2::handle_unary(method_sig, method, vec_ty, arch), OpSig::WidenNarrow(t) => { - if vec_ty.n_bits() > 128 && method == "widen" { - generic_op(method, sig, vec_ty) - } else { - mk_sse4_2::handle_widen_narrow(method_sig, method, vec_ty, scalar_bits, ty_bits, t) - } + handle_widen_narrow(method_sig, method, vec_ty, scalar_bits, ty_bits, t) } OpSig::Binary => mk_sse4_2::handle_binary(method_sig, method, vec_ty, arch), OpSig::Shift => mk_sse4_2::handle_shift(method_sig, method, vec_ty, scalar_bits, ty_bits), @@ -315,3 +315,106 @@ pub(crate) fn handle_compare( mk_sse4_2::handle_compare(method_sig, method, vec_ty, scalar_bits, ty_bits, arch) } } + +pub(crate) fn handle_widen_narrow( + method_sig: TokenStream, + method: &str, + vec_ty: &VecType, + scalar_bits: usize, + ty_bits: usize, + t: VecType, +) -> TokenStream { + let expr = match method { + "widen" => { + let dst_width = t.n_bits(); + match (dst_width, ty_bits) { + (256, 128) => { + let extend = + extend_intrinsic(vec_ty.scalar, scalar_bits, t.scalar_bits, dst_width); + quote! 
{ + unsafe { + #extend(a.into()).simd_into(self) + } + } + } + (512, 256) => { + let extend = + extend_intrinsic(vec_ty.scalar, scalar_bits, t.scalar_bits, ty_bits); + let combine = format_ident!( + "combine_{}", + VecType { + len: vec_ty.len / 2, + scalar_bits: scalar_bits * 2, + ..*vec_ty + } + .rust_name() + ); + let split = format_ident!("split_{}", vec_ty.rust_name()); + quote! { + unsafe { + let (a0, a1) = self.#split(a); + let high = #extend(a0.into()).simd_into(self); + let low = #extend(a1.into()).simd_into(self); + self.#combine(high, low) + } + } + } + _ => unimplemented!(), + } + } + "narrow" => { + let dst_width = t.n_bits(); + match (dst_width, ty_bits) { + (128, 256) => { + let mask = match t.scalar_bits { + 8 => { + quote! { 0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1 } + } + _ => unimplemented!(), + }; + quote! { + unsafe { + let mask = _mm256_setr_epi8(#mask, #mask); + + let shuffled = _mm256_shuffle_epi8(a.into(), mask); + let packed = _mm256_permute4x64_epi64::<0b11_01_10_00>(shuffled); + + _mm256_castsi256_si128(packed).simd_into(self) + } + } + } + (256, 512) => { + let mask = set1_intrinsic(vec_ty.scalar, scalar_bits, t.n_bits()); + let pack = pack_intrinsic( + scalar_bits, + matches!(vec_ty.scalar, ScalarType::Int), + t.n_bits(), + ); + let split = format_ident!("split_{}", vec_ty.rust_name()); + quote! { + let (a, b) = self.#split(a); + unsafe { + // Note that AVX2 only has an intrinsic for saturating cast, + // but not wrapping. + let mask = #mask(0xFF); + let lo_masked = _mm256_and_si256(a.into(), mask); + let hi_masked = _mm256_and_si256(b.into(), mask); + // The 256-bit version of packus_epi16 operates lane-wise, so we need to arrange things + // properly afterwards. + let result = _mm256_permute4x64_epi64::<0b_11_01_10_00>(#pack(lo_masked, hi_masked)); + result.simd_into(self) + } + } + } + _ => unimplemented!(), + } + } + _ => unreachable!(), + }; + + quote! 
{ + #method_sig { + #expr + } + } +} diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index db158a31..1d1e969f 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -83,13 +83,13 @@ fn mk_simd_impl() -> TokenStream { let mut methods = vec![]; for vec_ty in SIMD_TYPES { for (method, sig) in ops_for_type(vec_ty, true) { - let b1 = (vec_ty.n_bits() > 128 && !matches!(method, "split" | "narrow")) + let too_wide = (vec_ty.n_bits() > 128 && !matches!(method, "split" | "narrow")) || vec_ty.n_bits() > 256; - let b2 = !matches!(method, "load_interleaved_128") + let acceptable_wide_op = !matches!(method, "load_interleaved_128") && !matches!(method, "store_interleaved_128"); - if b1 && b2 { + if too_wide && acceptable_wide_op { methods.push(generic_op(method, sig, vec_ty)); continue; } From 843fc66b846bfaea0bce8037766355d26476150e Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 20:58:21 -0500 Subject: [PATCH 08/11] Fix spontaneous Clippy complaint --- fearless_simd_gen/src/mk_simd_types.rs | 45 +++++++++++++------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/fearless_simd_gen/src/mk_simd_types.rs b/fearless_simd_gen/src/mk_simd_types.rs index b1935285..ee71e374 100644 --- a/fearless_simd_gen/src/mk_simd_types.rs +++ b/fearless_simd_gen/src/mk_simd_types.rs @@ -207,29 +207,28 @@ fn simd_impl(ty: &VecType) -> TokenStream { | OpSig::Cvt(_, _) | OpSig::Reinterpret(_, _) | OpSig::Shift - ) { - if let Some(args) = sig.vec_trait_args() { - let ret_ty = sig.ret_ty(ty, TyFlavor::VecImpl); - let call_args = match sig { - OpSig::Unary | OpSig::Cvt(_, _) | OpSig::Reinterpret(_, _) => quote! { self }, - OpSig::Binary | OpSig::Compare | OpSig::Combine => { - quote! { self, rhs.simd_into(self.simd) } - } - OpSig::Shift => { - quote! { self, shift } - } - OpSig::Ternary => { - quote! { self, op1.simd_into(self.simd), op2.simd_into(self.simd) } - } - _ => quote! { todo!() }, - }; - methods.push(quote! { - #[inline(always)] - pub fn #method_name(#args) -> #ret_ty { - self.simd.#trait_method(#call_args) - } - }); - } + ) && let Some(args) = sig.vec_trait_args() + { + let ret_ty = sig.ret_ty(ty, TyFlavor::VecImpl); + let call_args = match sig { + OpSig::Unary | OpSig::Cvt(_, _) | OpSig::Reinterpret(_, _) => quote! { self }, + OpSig::Binary | OpSig::Compare | OpSig::Combine => { + quote! { self, rhs.simd_into(self.simd) } + } + OpSig::Shift => { + quote! { self, shift } + } + OpSig::Ternary => { + quote! { self, op1.simd_into(self.simd), op2.simd_into(self.simd) } + } + _ => quote! { todo!() }, + }; + methods.push(quote! 
{ + #[inline(always)] + pub fn #method_name(#args) -> #ret_ty { + self.simd.#trait_method(#call_args) + } + }); } } let vec_impl = simd_vec_impl(ty); From d0bff932a8b5b25dd8771d77cf979c48e482dde5 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Tue, 11 Nov 2025 21:24:22 -0500 Subject: [PATCH 09/11] Implement 8-bit multiplication in x86 --- fearless_simd/src/generated/avx2.rs | 53 ++++++++++++++++++++---- fearless_simd/src/generated/sse4_2.rs | 27 ++++++++---- fearless_simd_gen/src/mk_avx2.rs | 7 ---- fearless_simd_gen/src/mk_sse4_2.rs | 21 ++++++---- fearless_simd_tests/tests/harness/mod.rs | 40 ++++++++++++++++++ 5 files changed, 117 insertions(+), 31 deletions(-) diff --git a/fearless_simd/src/generated/avx2.rs b/fearless_simd/src/generated/avx2.rs index 5fed3847..9ddbc876 100644 --- a/fearless_simd/src/generated/avx2.rs +++ b/fearless_simd/src/generated/avx2.rs @@ -3,11 +3,6 @@ // This file is autogenerated by fearless_simd_gen -#![expect( - unused_variables, - clippy::todo, - reason = "TODO: https://github.com/linebender/fearless_simd/issues/40" -)] use crate::{Level, Simd, SimdFrom, SimdInto, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, @@ -238,7 +233,16 @@ impl Simd for Avx2 { } #[inline(always)] fn mul_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { - todo!() + unsafe { + let dst_even = _mm_mullo_epi16(a.into(), b.into()); + let dst_odd = + _mm_mullo_epi16(_mm_srli_epi16::<8>(a.into()), _mm_srli_epi16::<8>(b.into())); + _mm_or_si128( + _mm_slli_epi16(dst_odd, 8), + _mm_and_si128(dst_even, _mm_set1_epi16(0xFF)), + ) + .simd_into(self) + } } #[inline(always)] fn and_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { @@ -378,7 +382,16 @@ impl Simd for Avx2 { } #[inline(always)] fn mul_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { - todo!() + unsafe { + let dst_even = _mm_mullo_epi16(a.into(), b.into()); + let dst_odd = + _mm_mullo_epi16(_mm_srli_epi16::<8>(a.into()), _mm_srli_epi16::<8>(b.into())); + _mm_or_si128( + _mm_slli_epi16(dst_odd, 8), + _mm_and_si128(dst_even, _mm_set1_epi16(0xFF)), + ) + .simd_into(self) + } } #[inline(always)] fn and_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { @@ -1495,7 +1508,18 @@ impl Simd for Avx2 { } #[inline(always)] fn mul_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { - todo!() + unsafe { + let dst_even = _mm256_mullo_epi16(a.into(), b.into()); + let dst_odd = _mm256_mullo_epi16( + _mm256_srli_epi16::<8>(a.into()), + _mm256_srli_epi16::<8>(b.into()), + ); + _mm256_or_si256( + _mm256_slli_epi16(dst_odd, 8), + _mm256_and_si256(dst_even, _mm256_set1_epi16(0xFF)), + ) + .simd_into(self) + } } #[inline(always)] fn and_i8x32(self, a: i8x32, b: i8x32) -> i8x32 { @@ -1669,7 +1693,18 @@ impl Simd for Avx2 { } #[inline(always)] fn mul_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { - todo!() + unsafe { + let dst_even = _mm256_mullo_epi16(a.into(), b.into()); + let dst_odd = _mm256_mullo_epi16( + _mm256_srli_epi16::<8>(a.into()), + _mm256_srli_epi16::<8>(b.into()), + ); + _mm256_or_si256( + _mm256_slli_epi16(dst_odd, 8), + _mm256_and_si256(dst_even, _mm256_set1_epi16(0xFF)), + ) + .simd_into(self) + } } #[inline(always)] fn and_u8x32(self, a: u8x32, b: u8x32) -> u8x32 { diff --git a/fearless_simd/src/generated/sse4_2.rs b/fearless_simd/src/generated/sse4_2.rs index c9250c89..b6d53c39 100644 --- a/fearless_simd/src/generated/sse4_2.rs +++ b/fearless_simd/src/generated/sse4_2.rs @@ -3,11 +3,6 @@ // This file is autogenerated by fearless_simd_gen -#![expect( - unused_variables, - clippy::todo, - reason = "TODO: 
https://github.com/linebender/fearless_simd/issues/40" -)] use crate::{Level, Simd, SimdFrom, SimdInto, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, @@ -246,7 +241,16 @@ impl Simd for Sse4_2 { } #[inline(always)] fn mul_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { - todo!() + unsafe { + let dst_even = _mm_mullo_epi16(a.into(), b.into()); + let dst_odd = + _mm_mullo_epi16(_mm_srli_epi16::<8>(a.into()), _mm_srli_epi16::<8>(b.into())); + _mm_or_si128( + _mm_slli_epi16(dst_odd, 8), + _mm_and_si128(dst_even, _mm_set1_epi16(0xFF)), + ) + .simd_into(self) + } } #[inline(always)] fn and_i8x16(self, a: i8x16, b: i8x16) -> i8x16 { @@ -389,7 +393,16 @@ impl Simd for Sse4_2 { } #[inline(always)] fn mul_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { - todo!() + unsafe { + let dst_even = _mm_mullo_epi16(a.into(), b.into()); + let dst_odd = + _mm_mullo_epi16(_mm_srli_epi16::<8>(a.into()), _mm_srli_epi16::<8>(b.into())); + _mm_or_si128( + _mm_slli_epi16(dst_odd, 8), + _mm_and_si128(dst_even, _mm_set1_epi16(0xFF)), + ) + .simd_into(self) + } } #[inline(always)] fn and_u8x16(self, a: u8x16, b: u8x16) -> u8x16 { diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index 8ff93729..88007383 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -33,13 +33,6 @@ pub(crate) fn mk_avx2_impl() -> TokenStream { let ty_impl = mk_type_impl(); quote! { - // Until we have implemented all functions. - #![expect( - unused_variables, - clippy::todo, - reason = "TODO: https://github.com/linebender/fearless_simd/issues/40" - )] - #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index 1d1e969f..a5810c51 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -33,13 +33,6 @@ pub(crate) fn mk_sse4_2_impl() -> TokenStream { let ty_impl = mk_type_impl(); quote! { - // Until we have implemented all functions. - #![expect( - unused_variables, - clippy::todo, - reason = "TODO: https://github.com/linebender/fearless_simd/issues/40" - )] - #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] @@ -429,9 +422,21 @@ pub(crate) fn handle_binary( arch: impl Arch, ) -> TokenStream { if method == "mul" && vec_ty.scalar_bits == 8 { + // https://stackoverflow.com/questions/8193601/sse-multiplication-16-x-uint8-t + let mullo = intrinsic_ident("mullo", "epi16", vec_ty.n_bits()); + let set1 = intrinsic_ident("set1", "epi16", vec_ty.n_bits()); + let and = intrinsic_ident("and", coarse_type(*vec_ty), vec_ty.n_bits()); + let or = intrinsic_ident("or", coarse_type(*vec_ty), vec_ty.n_bits()); + let slli = intrinsic_ident("slli", "epi16", vec_ty.n_bits()); + let srli = intrinsic_ident("srli", "epi16", vec_ty.n_bits()); quote! 
{ #method_sig { - todo!() + unsafe { + let dst_even = #mullo(a.into(), b.into()); + let dst_odd = #mullo(#srli::<8>(a.into()), #srli::<8>(b.into())); + + #or(#slli(dst_odd, 8), #and(dst_even, #set1(0xFF))).simd_into(self) + } } } } else { diff --git a/fearless_simd_tests/tests/harness/mod.rs b/fearless_simd_tests/tests/harness/mod.rs index 35cd6e07..c7706348 100644 --- a/fearless_simd_tests/tests/harness/mod.rs +++ b/fearless_simd_tests/tests/harness/mod.rs @@ -2447,6 +2447,46 @@ fn trunc_f64x2(simd: S) { assert_eq!(a.trunc().val, [1.0, -2.0]); } +#[simd_test] +fn mul_u8x16(simd: S) { + let a = u8x16::from_slice( + simd, + &[0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 50, 60, 100], + ); + let b = u8x16::from_slice( + simd, + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 2], + ); + + assert_eq!( + (a * b).val, + [ + 0, 2, 6, 12, 20, 30, 70, 120, 180, 250, 74, 164, 8, 188, 132, 200 + ] + ); +} + +#[simd_test] +fn mul_i8x16(simd: S) { + let a = i8x16::from_slice( + simd, + &[ + 0, 1, -2, 3, -4, 5, 10, -15, 20, -25, 30, 35, -40, 50, -60, 100, + ], + ); + let b = i8x16::from_slice( + simd, + &[1, 2, 3, -4, 5, -6, 7, 8, 9, 10, -11, 12, 13, -14, 15, 2], + ); + + assert_eq!( + (a * b).val, + [ + 0, 2, -6, -12, -20, -30, 70, -120, -76, 6, -74, -92, -8, 68, 124, -56 + ] + ); +} + #[simd_test] fn mul_u16x8(simd: S) { let a = u16x8::from_slice(simd, &[0, 5, 10, 30, 500, 0, 0, 0]); From b29bf2ad1409c26fb226d197a3368c3a85d8973d Mon Sep 17 00:00:00 2001 From: valadaptive Date: Wed, 12 Nov 2025 09:49:32 -0500 Subject: [PATCH 10/11] Fix meaning of acceptable_wide_op in SSE4.2 gen --- fearless_simd_gen/src/mk_sse4_2.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index a5810c51..1589cb3b 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -79,10 +79,10 @@ fn mk_simd_impl() -> TokenStream { let too_wide = (vec_ty.n_bits() > 128 && !matches!(method, "split" | "narrow")) || vec_ty.n_bits() > 256; - let acceptable_wide_op = !matches!(method, "load_interleaved_128") - && !matches!(method, "store_interleaved_128"); + let acceptable_wide_op = matches!(method, "load_interleaved_128") + || matches!(method, "store_interleaved_128"); - if too_wide && acceptable_wide_op { + if too_wide && !acceptable_wide_op { methods.push(generic_op(method, sig, vec_ty)); continue; } From 2e4235db62639b06c17ee5cf88934f7422b0fda1 Mon Sep 17 00:00:00 2001 From: valadaptive Date: Thu, 13 Nov 2025 19:24:15 -0500 Subject: [PATCH 11/11] Add CHANGELOG entries --- CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14c1fb20..784b051f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,13 @@ This release has an [MSRV][] of 1.88. ### Added - All vector types now implement `Index` and `IndexMut`. ([#112][] by [@Ralith][]) +- 256-bit vector types now use native AVX2 intrinsics on supported platforms. ([#115][] by [@valadaptive][]) +- 8-bit integer multiplication is now implemented on x86. ([#115][] by [@valadaptive][]) + +### Fixed + +- Integer equality comparisons now function properly on x86. Previously, they performed "greater than" comparisons. + ([#115][] by [@valadaptive][]) ### Changed @@ -27,6 +34,7 @@ This release has an [MSRV][] of 1.88. A consequence of this is that the available variants on `Level` are now dependent on the target features you are compiling with. 
The fallback level can be restored with the `force_support_fallback` cargo feature. We don't expect this to be necessary outside of tests. +- Code generation for `select` and `unzip` operations on x86 has been improved. ([#115][] by [@valadaptive][]) ### Removed @@ -86,6 +94,7 @@ No changelog was kept for this release. [@Ralith]: https://github.com/Ralith [@DJMcNab]: https://github.com/DJMcNab +[@valadaptive]: https://github.com/valadaptive [#75]: https://github.com/linebender/fearless_simd/pull/75 [#76]: https://github.com/linebender/fearless_simd/pull/76 @@ -103,6 +112,7 @@ No changelog was kept for this release. [#96]: https://github.com/linebender/fearless_simd/pull/96 [#99]: https://github.com/linebender/fearless_simd/pull/99 [#105]: https://github.com/linebender/fearless_simd/pull/105 +[#115]: https://github.com/linebender/fearless_simd/pull/115 [Unreleased]: https://github.com/linebender/fearless_simd/compare/v0.3.0...HEAD [0.3.0]: https://github.com/linebender/fearless_simd/compare/v0.3.0...v0.2.0
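The 8-bit multiplication added in PATCH 09 works around the lack of 8-bit `mullo` intrinsics on x86: even and odd bytes are multiplied in 16-bit lanes (one `mullo_epi16` on the operands as-is, one on both operands shifted right by 8), and the two partial results are then masked and recombined. The scalar sketch below mirrors that recombination for a single 16-bit lane so the mask-and-shift step can be checked against `u8::wrapping_mul`; the `mul_packed_u8_pair` helper and its test values are illustrative only and are not part of the crate or of this patch series.

/// One 16-bit lane holding two packed u8 values: `a0` in the low byte,
/// `a1` in the high byte. Mirrors, per lane, what the mullo/shift/mask
/// sequence in the x86 backends does on whole vectors.
fn mul_packed_u8_pair(a: u16, b: u16) -> u16 {
    // Full 16-bit product; its low byte is a0 * b0 (mod 256), while the
    // high byte is garbage and gets masked away below.
    let dst_even = a.wrapping_mul(b);
    // Shift the high bytes down and multiply; the low byte of this
    // product is a1 * b1 (mod 256).
    let dst_odd = (a >> 8).wrapping_mul(b >> 8);
    // Recombine: odd result back into the high byte, even result masked
    // to the low byte.
    (dst_odd << 8) | (dst_even & 0x00FF)
}

fn main() {
    for (a0, b0, a1, b1) in [(3u8, 7u8, 200u8, 5u8), (255, 255, 17, 3), (0, 9, 128, 2)] {
        let a = u16::from(a0) | (u16::from(a1) << 8);
        let b = u16::from(b0) | (u16::from(b1) << 8);
        let r = mul_packed_u8_pair(a, b);
        assert_eq!(r as u8, a0.wrapping_mul(b0));
        assert_eq!((r >> 8) as u8, a1.wrapping_mul(b1));
    }
    println!("even/odd 16-bit-lane trick matches u8::wrapping_mul");
}

The same three-step pattern (mullo, shift, mask-and-or) is what the generator's `handle_binary` emits for both the 128-bit and 256-bit integer vector types.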