diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs index 2a663d300a..615821b2e3 100644 --- a/src/shims/x86/mod.rs +++ b/src/shims/x86/mod.rs @@ -664,54 +664,35 @@ fn convert_float_to_int<'tcx>( Ok(()) } -/// Splits `left`, `right` and `dest` (which must be SIMD vectors) -/// into 128-bit chuncks. -/// -/// `left`, `right` and `dest` cannot have different types. +/// Splits `op` (which must be a SIMD vector) into 128-bit chuncks. /// /// Returns a tuple where: /// * The first element is the number of 128-bit chunks (let's call it `N`). /// * The second element is the number of elements per chunk (let's call it `M`). -/// * The third element is the `left` vector split into chunks, i.e, it's -/// type is `[[T; M]; N]`. -/// * The fourth element is the `right` vector split into chunks. -/// * The fifth element is the `dest` vector split into chunks. -fn split_simd_to_128bit_chunks<'tcx>( +/// * The third element is the `op` vector split into chunks, i.e, it's +/// type is `[[T; M]; N]` where `T` is the element type of `op`. +fn split_simd_to_128bit_chunks<'tcx, P: Projectable<'tcx, Provenance>>( this: &mut crate::MiriInterpCx<'_, 'tcx>, - left: &OpTy<'tcx, Provenance>, - right: &OpTy<'tcx, Provenance>, - dest: &MPlaceTy<'tcx, Provenance>, -) -> InterpResult< - 'tcx, - (u64, u64, MPlaceTy<'tcx, Provenance>, MPlaceTy<'tcx, Provenance>, MPlaceTy<'tcx, Provenance>), -> { - assert_eq!(dest.layout, left.layout); - assert_eq!(dest.layout, right.layout); + op: &P, +) -> InterpResult<'tcx, (u64, u64, P)> { + let simd_layout = op.layout(); + let (simd_len, element_ty) = simd_layout.ty.simd_size_and_type(this.tcx.tcx); - let (left, left_len) = this.operand_to_simd(left)?; - let (right, right_len) = this.operand_to_simd(right)?; - let (dest, dest_len) = this.mplace_to_simd(dest)?; - - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - - assert_eq!(dest.layout.size.bits() % 128, 0); - let num_chunks = dest.layout.size.bits() / 128; - assert_eq!(dest_len.checked_rem(num_chunks), Some(0)); - let items_per_chunk = dest_len.checked_div(num_chunks).unwrap(); + assert_eq!(simd_layout.size.bits() % 128, 0); + let num_chunks = simd_layout.size.bits() / 128; + let items_per_chunk = simd_len.checked_div(num_chunks).unwrap(); // Transmute to `[[T; items_per_chunk]; num_chunks]` - let element_layout = left.layout.field(this, 0); - let chunked_layout = this.layout_of(Ty::new_array( - this.tcx.tcx, - Ty::new_array(this.tcx.tcx, element_layout.ty, items_per_chunk), - num_chunks, - ))?; - let left = left.transmute(chunked_layout, this)?; - let right = right.transmute(chunked_layout, this)?; - let dest = dest.transmute(chunked_layout, this)?; - - Ok((num_chunks, items_per_chunk, left, right, dest)) + let chunked_layout = this + .layout_of(Ty::new_array( + this.tcx.tcx, + Ty::new_array(this.tcx.tcx, element_ty, items_per_chunk), + num_chunks, + )) + .unwrap(); + let chunked_op = op.transmute(chunked_layout, this)?; + + Ok((num_chunks, items_per_chunk, chunked_op)) } /// Horizontaly performs `which` operation on adjacent values of @@ -731,8 +712,12 @@ fn horizontal_bin_op<'tcx>( right: &OpTy<'tcx, Provenance>, dest: &MPlaceTy<'tcx, Provenance>, ) -> InterpResult<'tcx, ()> { - let (num_chunks, items_per_chunk, left, right, dest) = - split_simd_to_128bit_chunks(this, left, right, dest)?; + assert_eq!(left.layout, dest.layout); + assert_eq!(right.layout, dest.layout); + + let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(this, left)?; + let (_, _, right) = split_simd_to_128bit_chunks(this, right)?; + let (_, _, dest) = split_simd_to_128bit_chunks(this, dest)?; let middle = items_per_chunk / 2; for i in 0..num_chunks { @@ -779,8 +764,12 @@ fn conditional_dot_product<'tcx>( imm: &OpTy<'tcx, Provenance>, dest: &MPlaceTy<'tcx, Provenance>, ) -> InterpResult<'tcx, ()> { - let (num_chunks, items_per_chunk, left, right, dest) = - split_simd_to_128bit_chunks(this, left, right, dest)?; + assert_eq!(left.layout, dest.layout); + assert_eq!(right.layout, dest.layout); + + let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(this, left)?; + let (_, _, right) = split_simd_to_128bit_chunks(this, right)?; + let (_, _, dest) = split_simd_to_128bit_chunks(this, dest)?; let element_layout = left.layout.field(this, 0).field(this, 0); assert!(items_per_chunk <= 4);