From 16eaf64756cc66ff5b7bd1b3a21e1cdbecaaee74 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 24 Sep 2025 21:22:00 -0700 Subject: [PATCH 01/12] Switch rspirv to the latest git version --- Cargo.lock | 10 ++++------ Cargo.toml | 2 ++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7ccdeeccbe..3f6f281d55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3584,9 +3584,8 @@ dependencies = [ [[package]] name = "rspirv" -version = "0.12.0+sdk-1.3.268.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" +version = "0.12.0+sdk-1.4.309.0" +source = "git+https://github.com/gfx-rs/rspirv?rev=89ce4d0e64c91b0635f617409dc57cb031749a39#89ce4d0e64c91b0635f617409dc57cb031749a39" dependencies = [ "rustc-hash 1.1.0", "spirv", @@ -3961,9 +3960,8 @@ dependencies = [ [[package]] name = "spirv" -version = "0.3.0+sdk-1.3.268.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" +version = "0.3.0+sdk-1.4.309.0" +source = "git+https://github.com/gfx-rs/rspirv?rev=89ce4d0e64c91b0635f617409dc57cb031749a39#89ce4d0e64c91b0635f617409dc57cb031749a39" dependencies = [ "bitflags 2.9.4", "serde", diff --git a/Cargo.toml b/Cargo.toml index bd58ec379e..24360b6ace 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -254,6 +254,8 @@ ndk-sys = "0.6" # These overrides allow our examples to explicitly depend on release crates [patch.crates-io] wgpu = { path = "./wgpu" } +rspirv = { git = "https://github.com/gfx-rs/rspirv", rev = "89ce4d0e64c91b0635f617409dc57cb031749a39" } +spirv = { git = "https://github.com/gfx-rs/rspirv", rev = "89ce4d0e64c91b0635f617409dc57cb031749a39" } [profile.release] lto = "thin" From d035ecb72a574df169965c3f749fadacf392854c Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 16 Sep 2025 23:32:57 -0700 Subject: [PATCH 02/12] Add Cooperative* type 
to IR --- naga/src/back/glsl/mod.rs | 3 ++- naga/src/back/msl/writer.rs | 23 ++++++++++++++++ naga/src/back/spv/writer.rs | 2 ++ naga/src/common/wgsl/to_wgsl.rs | 19 ++++++++++++++ naga/src/common/wgsl/types.rs | 13 +++++++++ naga/src/compact/types.rs | 2 ++ naga/src/front/wgsl/lower/conversion.rs | 2 ++ naga/src/ir/mod.rs | 35 +++++++++++++++++++++++++ naga/src/proc/layouter.rs | 18 +++++++++++++ naga/src/proc/type_methods.rs | 9 ++++++- naga/src/valid/handles.rs | 1 + naga/src/valid/mod.rs | 1 + naga/src/valid/type.rs | 19 ++++++++++++++ 13 files changed, 145 insertions(+), 2 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 4c5a9d8cbc..515ffd7b1e 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -1107,7 +1107,8 @@ impl<'a, W: Write> Writer<'a, W> { TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?, // Write all variants instead of `_` so that if new variants are added a // no exhaustiveness error is thrown - TypeInner::Pointer { .. } + TypeInner::CooperativeMatrix { .. } + | TypeInner::Pointer { .. } | TypeInner::Struct { .. } | TypeInner::Image { .. } | TypeInner::Sampler { .. 
} diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index cec9226541..9da4d51826 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -235,6 +235,20 @@ impl Display for TypeContext<'_> { rows, scalar, } => put_numeric_type(out, scalar, &[rows, columns]), + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + } => { + write!( + out, + "{}::simdgroup_{}{}x{}", + NAMESPACE, + scalar.to_msl_name(), + columns as u32, + rows as u32, + ) + } crate::TypeInner::Pointer { base, space } => { let sub = Self { handle: base, @@ -528,6 +542,14 @@ impl crate::Scalar { } } +impl crate::CooperativeScalar { + const fn to_msl_name(self) -> &'static str { + match self { + Self::F32 => "float", + } + } +} + const fn separate(need_separator: bool) -> &'static str { if need_separator { "," @@ -637,6 +659,7 @@ impl crate::Type { Ti::Scalar(_) | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic(_) | Ti::Pointer { .. } | Ti::ValuePointer { .. } => self.name.is_some(), diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 636766d1e5..16831bc45f 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -436,6 +436,7 @@ impl Writer { // these cases, so unwrap. LocalType::Numeric(NumericType::from_inner(inner).unwrap()) } + crate::TypeInner::CooperativeMatrix { .. } => return None, crate::TypeInner::Pointer { base, space } => { let base_type_id = self.get_handle_type_id(base); LocalType::Pointer { @@ -1500,6 +1501,7 @@ impl Writer { | crate::TypeInner::Atomic(_) | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Image { .. 
} diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 72be441288..91a04be7af 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -299,6 +299,25 @@ impl TryToWgsl for crate::Scalar { } } +impl TryToWgsl for crate::CooperativeScalar { + const DESCRIPTION: &'static str = "cooperative scalar type"; + + fn try_to_wgsl(self) -> Option<&'static str> { + use crate::CooperativeScalar; + + Some(match self { + CooperativeScalar::F32 => "f32", + }) + } + + fn to_wgsl_for_diagnostics(self) -> String { + match self.try_to_wgsl() { + Some(static_string) => static_string.to_string(), + None => unreachable!(), + } + } +} + impl ToWgsl for crate::ImageDimension { fn to_wgsl(self) -> &'static str { use crate::ImageDimension as IDim; diff --git a/naga/src/common/wgsl/types.rs b/naga/src/common/wgsl/types.rs index 82b8eeaa67..93a94205d7 100644 --- a/naga/src/common/wgsl/types.rs +++ b/naga/src/common/wgsl/types.rs @@ -317,6 +317,19 @@ where ctx.write_scalar(scalar, out)?; out.write_str(">")?; } + TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + } => { + write!( + out, + "coop_mat{}x{}<{}>", + columns as u32, + rows as u32, + scalar.try_to_wgsl().unwrap_or_default() + )?; + } TypeInner::Pointer { base, space } => { let (address, maybe_access) = address_space_str(space); // Everything but `AddressSpace::Handle` gives us a `address` name, but diff --git a/naga/src/compact/types.rs b/naga/src/compact/types.rs index 0a1db16f9f..d06558b182 100644 --- a/naga/src/compact/types.rs +++ b/naga/src/compact/types.rs @@ -16,6 +16,7 @@ impl TypeTracer<'_> { Ti::Scalar { .. } | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic { .. } | Ti::ValuePointer { .. } | Ti::Image { .. } @@ -66,6 +67,7 @@ impl ModuleMap { Ti::Scalar(_) | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic(_) | Ti::ValuePointer { .. } | Ti::Image { .. 
} diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs index b22692a3cd..9e03ed5c9e 100644 --- a/naga/src/front/wgsl/lower/conversion.rs +++ b/naga/src/front/wgsl/lower/conversion.rs @@ -350,6 +350,7 @@ impl crate::TypeInner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { Some(scalar) } + Ti::CooperativeMatrix { .. } => None, Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), Ti::Atomic(_) | Ti::Pointer { .. } @@ -375,6 +376,7 @@ impl crate::TypeInner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { Some(scalar) } + Ti::CooperativeMatrix { .. } => None, Ti::Atomic(_) => None, Ti::Pointer { base, .. } | Ti::Array { base, .. } => { types[base].inner.automatically_convertible_scalar(types) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 257445952b..e70e1f650d 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -437,6 +437,16 @@ impl From for u32 { } } +/// Number of components in a cooperative vector. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeVectorSize { + Eight = 8, +} + /// Primitive type for a scalar. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -464,6 +474,24 @@ pub enum ScalarKind { AbstractFloat, } +/// Primitive type for a scalar. 
+#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeScalar { + F32, +} + +impl CooperativeScalar { + pub const fn width(&self) -> Bytes { + match *self { + Self::F32 => 4, + } + } +} + /// Characteristics of a scalar type. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -712,6 +740,13 @@ pub enum TypeInner { rows: VectorSize, scalar: Scalar, }, + /// Matrix that is cooperatively processed by all the threads + /// in an opaque mapping. + CooperativeMatrix { + columns: CooperativeVectorSize, + rows: CooperativeVectorSize, + scalar: CooperativeScalar, + }, /// Atomic scalar. Atomic(Scalar), /// Pointer to another type. diff --git a/naga/src/proc/layouter.rs b/naga/src/proc/layouter.rs index 204a523c91..5e7aed8a0f 100644 --- a/naga/src/proc/layouter.rs +++ b/naga/src/proc/layouter.rs @@ -86,6 +86,12 @@ impl From for Alignment { } } +impl From for Alignment { + fn from(size: crate::CooperativeVectorSize) -> Self { + Self(unsafe { NonZeroU32::new_unchecked(size as u32) }) + } +} + /// Size and alignment information for a type. #[derive(Clone, Copy, Debug, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -212,6 +218,18 @@ impl Layouter { alignment: Alignment::from(rows) * alignment, } } + Ti::CooperativeMatrix { + columns: _, + rows, + scalar, + } => { + let alignment = Alignment::new(scalar.width() as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(rows) * alignment, + } + } Ti::Pointer { .. } | Ti::ValuePointer { .. 
} => TypeLayout { size, alignment: Alignment::ONE, diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index c59d524f13..24a14868f9 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -202,6 +202,11 @@ impl crate::TypeInner { rows, scalar, } => Some(super::Alignment::from(rows) * scalar.width as u32 * columns as u32), + Self::CooperativeMatrix { + columns, + rows, + scalar, + } => Some(columns as u32 * rows as u32 * scalar.width() as u32), Self::Pointer { .. } | Self::ValuePointer { .. } => Some(POINTER_SPAN), Self::Array { base: _, @@ -361,6 +366,7 @@ impl crate::TypeInner { crate::TypeInner::Scalar(scalar) => Some((None, scalar)), crate::TypeInner::Vector { size, scalar } => Some((Some(size), scalar)), crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::Atomic(_) | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } @@ -385,7 +391,8 @@ impl crate::TypeInner { | crate::TypeInner::Matrix { scalar, .. } | crate::TypeInner::Atomic(scalar) => scalar.is_abstract(), crate::TypeInner::Array { base, .. } => types[base].inner.is_abstract(types), - crate::TypeInner::ValuePointer { .. } + crate::TypeInner::CooperativeMatrix { .. } + | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Pointer { .. } | crate::TypeInner::Struct { .. } | crate::TypeInner::Image { .. } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e8a6901343..2cfb32ebe1 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -379,6 +379,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Atomic { .. } | crate::TypeInner::Image { .. 
} diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 426b3d637d..bb3408181f 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -451,6 +451,7 @@ impl crate::TypeInner { Self::Scalar { .. } | Self::Vector { .. } | Self::Matrix { .. } + | Self::CooperativeMatrix { .. } | Self::Array { size: crate::ArraySize::Constant(_), .. diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index e8b83ff08f..f4a2271ea3 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -415,6 +415,25 @@ impl super::Validator { type_info.push_constant_compatibility = push_constant_compatibility; type_info } + Ti::CooperativeMatrix { + columns: _, + rows: _, + scalar, + } => { + if scalar != crate::CooperativeScalar::F32 { + return Err(TypeError::MatrixElementNotFloat); + } + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED, + Alignment::from_width(scalar.width()), + ) + } Ti::Atomic(scalar) => { match scalar { crate::Scalar { From a05fdf998b3e252e970bb541237d147d3b6430b0 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 17 Sep 2025 00:00:42 -0700 Subject: [PATCH 03/12] coop: first bits of Vulkan support for the type --- naga/src/back/spv/instructions.rs | 16 +++++++++++ naga/src/back/spv/mod.rs | 28 +++++++++++++++++++ naga/src/back/spv/writer.rs | 46 +++++++++++++++++++++++++++---- naga/src/valid/mod.rs | 2 ++ naga/src/valid/type.rs | 1 + 5 files changed, 88 insertions(+), 5 deletions(-) diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 788c3bc119..5e8c22e62d 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -281,6 +281,22 @@ impl super::Instruction { instruction } + pub(super) fn type_coop_matrix( + id: Word, + scalar_type_id: Word, + row_count: crate::CooperativeVectorSize, + column_count: crate::CooperativeVectorSize, 
+ ) -> Self { + let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR); + instruction.set_result(id); + instruction.add_operand(scalar_type_id); + instruction.add_operand(spirv::Scope::Subgroup as u32); + instruction.add_operand(column_count as u32); + instruction.add_operand(row_count as u32); + instruction.add_operand(spirv::CooperativeMatrixUse::MatrixAKHR as u32); //TODO: configure or expose + instruction + } + #[allow(clippy::too_many_arguments)] pub(super) fn type_image( id: Word, diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 371b3f7dbe..de09b91595 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -340,6 +340,33 @@ impl NumericType { } } +/// A cooperative type, for use in [`LocalType`]. +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum CooperativeType { + Matrix { + columns: crate::CooperativeVectorSize, + rows: crate::CooperativeVectorSize, + scalar: crate::CooperativeScalar, + }, +} + +impl CooperativeType { + const fn from_inner(inner: &crate::TypeInner) -> Option { + match *inner { + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + } => Some(Self::Matrix { + columns, + rows, + scalar, + }), + _ => None, + } + } +} + /// A SPIR-V type constructed during code generation. /// /// This is the variant of [`LookupType`] used to represent types that might not @@ -389,6 +416,7 @@ impl NumericType { enum LocalType { /// A numeric type. 
Numeric(NumericType), + Cooperative(CooperativeType), Pointer { base: Word, class: spirv::StorageClass, diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 16831bc45f..69416ffacb 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -6,10 +6,11 @@ use spirv::Word; use super::{ block::DebugInfoInner, helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, - Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error, - Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType, - LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options, - PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, + Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo, + EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, + LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, + NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, + BITS_PER_BYTE, }; use crate::{ arena::{Handle, HandleVec, UniqueArena}, @@ -375,6 +376,12 @@ impl Writer { }) } + pub(super) fn get_cooperative_type_id(&mut self, scalar: crate::CooperativeScalar) -> Word { + match scalar { + crate::CooperativeScalar::F32 => self.get_f32_type_id(), + } + } + pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { let f32_id = self.get_f32_type_id(); self.get_pointer_type_id(f32_id, class) @@ -436,7 +443,9 @@ impl Writer { // these cases, so unwrap. LocalType::Numeric(NumericType::from_inner(inner).unwrap()) } - crate::TypeInner::CooperativeMatrix { .. } => return None, + crate::TypeInner::CooperativeMatrix { .. 
} => { + LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap()) + } crate::TypeInner::Pointer { base, space } => { let base_type_id = self.get_handle_type_id(base); LocalType::Pointer { @@ -1353,6 +1362,14 @@ impl Writer { self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?; self.use_extension("SPV_KHR_16bit_storage"); } + // Cooperative types and ops + crate::TypeInner::CooperativeMatrix { .. } => { + self.require_any( + "cooperative matrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + self.use_extension("SPV_KHR_cooperative_matrix"); + } _ => {} } Ok(()) @@ -1379,12 +1396,31 @@ impl Writer { instruction.to_words(&mut self.logical_layout.declarations); } + fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) { + let instruction = match coop { + CooperativeType::Matrix { + columns, + rows, + scalar, + } => { + let scalar_id = self.get_cooperative_type_id(scalar); + Instruction::type_coop_matrix(id, scalar_id, rows, columns) + } + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) { let instruction = match local_ty { LocalType::Numeric(numeric) => { self.write_numeric_type_declaration_local(id, numeric); return; } + LocalType::Cooperative(coop) => { + self.write_cooperative_type_declaration_local(id, coop); + return; + } LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base), LocalType::Image(image) => { let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type)); diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index bb3408181f..fa0fdb0d39 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -186,6 +186,8 @@ bitflags::bitflags! { /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store /// `f16`-precision values in `f32`s. 
const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28; + /// Support for cooperative matrix types and operations + const COOPERATIVE_MATRIX = 1 << 29; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index f4a2271ea3..155991b0e8 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -420,6 +420,7 @@ impl super::Validator { rows: _, scalar, } => { + self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; if scalar != crate::CooperativeScalar::F32 { return Err(TypeError::MatrixElementNotFloat); } From 23f0cf8ffd3476faa493cef6f57fc972b3b7f5f2 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 19 Sep 2025 22:38:29 -0700 Subject: [PATCH 04/12] coop: wgsl parsing, IR role --- naga/src/back/msl/writer.rs | 1 + naga/src/back/spv/instructions.rs | 17 +++- naga/src/back/spv/mod.rs | 7 +- naga/src/back/spv/writer.rs | 3 +- naga/src/common/wgsl/to_wgsl.rs | 20 +++-- naga/src/common/wgsl/types.rs | 6 +- naga/src/front/wgsl/error.rs | 12 +++ naga/src/front/wgsl/lower/construction.rs | 26 ++++++ naga/src/front/wgsl/lower/mod.rs | 22 +++++ naga/src/front/wgsl/parse/ast.rs | 22 +++++ naga/src/front/wgsl/parse/lexer.rs | 12 +++ naga/src/front/wgsl/parse/mod.rs | 58 +++++++++++++ naga/src/ir/mod.rs | 21 ++++- naga/src/proc/layouter.rs | 5 +- naga/src/proc/type_methods.rs | 1 + naga/src/valid/type.rs | 1 + naga/tests/in/wgsl/cooperative-matrix.toml | 2 + naga/tests/in/wgsl/cooperative-matrix.wgsl | 7 ++ .../analysis/wgsl-cooperative-matrix.info.ron | 78 +++++++++++++++++ .../ir/wgsl-cooperative-matrix.compact.ron | 84 +++++++++++++++++++ naga/tests/out/ir/wgsl-cooperative-matrix.ron | 84 +++++++++++++++++++ .../out/spv/wgsl-cooperative-matrix.spvasm | 17 ++++ 22 files changed, 486 insertions(+), 20 deletions(-) create mode 100644 naga/tests/in/wgsl/cooperative-matrix.toml create mode 100644 naga/tests/in/wgsl/cooperative-matrix.wgsl create mode 100644 naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron create mode 100644 
naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron create mode 100644 naga/tests/out/ir/wgsl-cooperative-matrix.ron create mode 100644 naga/tests/out/spv/wgsl-cooperative-matrix.spvasm diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 9da4d51826..100be84f5c 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -239,6 +239,7 @@ impl Display for TypeContext<'_> { columns, rows, scalar, + role: _, } => { write!( out, diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 5e8c22e62d..bb559606d9 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -284,8 +284,9 @@ impl super::Instruction { pub(super) fn type_coop_matrix( id: Word, scalar_type_id: Word, - row_count: crate::CooperativeVectorSize, - column_count: crate::CooperativeVectorSize, + row_count: crate::CooperativeSize, + column_count: crate::CooperativeSize, + role: spirv::CooperativeMatrixUse, ) -> Self { let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR); instruction.set_result(id); @@ -293,7 +294,7 @@ impl super::Instruction { instruction.add_operand(spirv::Scope::Subgroup as u32); instruction.add_operand(column_count as u32); instruction.add_operand(row_count as u32); - instruction.add_operand(spirv::CooperativeMatrixUse::MatrixAKHR as u32); //TODO: configure or expose + instruction.add_operand(role as u32); instruction } @@ -1305,3 +1306,13 @@ impl From for spirv::Dim { } } } + +impl From for spirv::CooperativeMatrixUse { + fn from(role: crate::CooperativeRole) -> Self { + match role { + crate::CooperativeRole::A => Self::MatrixAKHR, + crate::CooperativeRole::B => Self::MatrixBKHR, + crate::CooperativeRole::C => Self::MatrixAccumulatorKHR, + } + } +} diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index de09b91595..c283b035a7 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -344,9 +344,10 @@ impl NumericType { #[derive(Debug, 
PartialEq, Hash, Eq, Copy, Clone)] enum CooperativeType { Matrix { - columns: crate::CooperativeVectorSize, - rows: crate::CooperativeVectorSize, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, scalar: crate::CooperativeScalar, + role: crate::CooperativeRole, }, } @@ -357,10 +358,12 @@ impl CooperativeType { columns, rows, scalar, + role, } => Some(Self::Matrix { columns, rows, scalar, + role, }), _ => None, } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 69416ffacb..1b31194fec 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1402,9 +1402,10 @@ impl Writer { columns, rows, scalar, + role, } => { let scalar_id = self.get_cooperative_type_id(scalar); - Instruction::type_coop_matrix(id, scalar_id, rows, columns) + Instruction::type_coop_matrix(id, scalar_id, rows, columns, role.into()) } }; diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 91a04be7af..1a3f5e5e17 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -318,15 +318,23 @@ impl TryToWgsl for crate::CooperativeScalar { } } -impl ToWgsl for crate::ImageDimension { +impl ToWgsl for crate::CooperativeRole { fn to_wgsl(self) -> &'static str { - use crate::ImageDimension as IDim; + match self { + Self::A => "A", + Self::B => "B", + Self::C => "C", + } + } +} +impl ToWgsl for crate::ImageDimension { + fn to_wgsl(self) -> &'static str { match self { - IDim::D1 => "1d", - IDim::D2 => "2d", - IDim::D3 => "3d", - IDim::Cube => "cube", + Self::D1 => "1d", + Self::D2 => "2d", + Self::D3 => "3d", + Self::Cube => "cube", } } } diff --git a/naga/src/common/wgsl/types.rs b/naga/src/common/wgsl/types.rs index 93a94205d7..a678a617f7 100644 --- a/naga/src/common/wgsl/types.rs +++ b/naga/src/common/wgsl/types.rs @@ -321,13 +321,15 @@ where columns, rows, scalar, + role, } => { write!( out, - "coop_mat{}x{}<{}>", + "coop_mat{}x{}<{},{}>", columns as u32, rows as u32, - 
scalar.try_to_wgsl().unwrap_or_default() + scalar.try_to_wgsl().unwrap_or_default(), + role.to_wgsl(), )?; } TypeInner::Pointer { base, space } => { diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 17dab5cb0e..8c749acc73 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -412,6 +412,8 @@ pub(crate) enum Error<'a> { TypeTooLarge { span: Span, }, + UnderspecifiedCooperativeMatrix, + UnknownCooperativeScalar(Span), } impl From for Error<'_> { @@ -1386,6 +1388,16 @@ impl<'a> Error<'a> { crate::valid::MAX_TYPE_SIZE )], }, + Error::UnderspecifiedCooperativeMatrix => ParseError { + message: "cooperative matrix constructor is underspecified".into(), + labels: vec![], + notes: vec![format!("must be F32")], + }, + Error::UnknownCooperativeScalar(span) => ParseError { + message: "unknown cooperative scalar type".into(), + labels: vec![(span, "type needs the scalar type specified".into())], + notes: vec![format!("must be F32")], + }, } } } diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 997d5a3123..9ac11bfc98 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -638,6 +638,32 @@ impl<'source> Lowerer<'source, '_> { }; Constructor::Type(ty) } + ast::ConstructorType::PartialCooperativeMatrix { .. 
} => { + return Err(Box::new(Error::UnderspecifiedCooperativeMatrix)); + } + ast::ConstructorType::CooperativeMatrix { + rows, + columns, + ty, + ty_span, + role, + } => { + let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?; + let scalar = match ctx.module.types[ty].inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }) => crate::CooperativeScalar::F32, + _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + }; + let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + }); + Constructor::Type(ty) + } ast::ConstructorType::PartialArray => Constructor::PartialArray, ast::ConstructorType::Array { base, size } => { let base = self.resolve_ast_type(base, &mut ctx.as_const())?; diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e90d7eab0a..b599223561 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3955,6 +3955,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { _ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))), } } + ast::Type::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + } => { + let ty = self.resolve_ast_type(ty, ctx)?; + let scalar = match ctx.module.types[ty].inner { + ir::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }) => crate::CooperativeScalar::F32, + _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + }; + ir::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + } + } ast::Type::Atomic(scalar) => scalar.to_inner_atomic(), ast::Type::Pointer { base, space } => { let base = self.resolve_ast_type(base, ctx)?; diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 345e9c4c48..af05a84110 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -235,6 +235,13 @@ pub enum Type<'a> { ty: Handle>, 
ty_span: Span, }, + CooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ty: Handle>, + ty_span: Span, + role: crate::CooperativeRole, + }, Atomic(Scalar), Pointer { base: Handle>, @@ -385,6 +392,21 @@ pub enum ConstructorType<'a> { ty_span: Span, }, + /// A cooperative matrix construction base `coop_mat8x8(...)`. + PartialCooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + }, + + /// A full cooperative matrix construction `coop_mat8x8(...)`. + CooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ty: Handle>, + ty_span: Span, + role: crate::CooperativeRole, + }, + /// An array whose component type and size are inferred from the arguments: /// `array(3,4,5)`. PartialArray, diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index d0a8033987..ed87e37100 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -584,6 +584,18 @@ impl<'a> Lexer<'a> { }) } + pub(in crate::front::wgsl) fn next_cooperative_role( + &mut self, + ) -> Result<'a, crate::CooperativeRole> { + let (ident, span) = self.next_ident_with_span()?; + match ident { + "A" => Ok(crate::CooperativeRole::A), + "B" => Ok(crate::CooperativeRole::B), + "C" => Ok(crate::CooperativeRole::C), + _ => Err(Box::new(Error::UnknownAccess(span))), + } + } + pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<'a, ()> { self.expect(Token::Paren('(')) } diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index c01ba4de30..49d7eaab25 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -658,6 +658,12 @@ impl Parser { ty_span: Span::UNDEFINED, })) } + "coop_mat8x8" => { + return Ok(Some(ast::ConstructorType::PartialCooperativeMatrix { + columns: crate::CooperativeSize::Eight, + rows: crate::CooperativeSize::Eight, + })) + } "array" => ast::ConstructorType::PartialArray, 
"atomic" | "binding_array" @@ -701,6 +707,19 @@ impl Parser { ty_span, })) } + ( + Token::Paren('<'), + ast::ConstructorType::PartialCooperativeMatrix { columns, rows }, + ) => { + let (ty, ty_span, role) = self.cooperative_scalar_and_role(lexer, ctx)?; + Ok(Some(ast::ConstructorType::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + })) + } (Token::Paren('<'), ast::ConstructorType::PartialArray) => { lexer.expect_generic_paren('<')?; let base = self.type_decl(lexer, ctx)?; @@ -1437,6 +1456,22 @@ impl Parser { Ok((ty, span)) } + /// Parses ``, returning (T, span of T, R, span of R) + fn cooperative_scalar_and_role<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<'a, (Handle>, Span, crate::CooperativeRole)> { + lexer.expect_generic_paren('<')?; + let start = lexer.start_byte_offset(); + let ty = self.type_decl(lexer, ctx)?; + let ty_span = lexer.span_from(start); + lexer.expect(Token::Separator(','))?; + let role = lexer.next_cooperative_role()?; + lexer.expect_generic_paren('>')?; + Ok((ty, ty_span, role)) + } + fn matrix_with_type<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -1453,6 +1488,23 @@ impl Parser { }) } + fn cooperative_matrix_with_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ) -> Result<'a, ast::Type<'a>> { + let (ty, ty_span, role) = self.cooperative_scalar_and_role(lexer, ctx)?; + Ok(ast::Type::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + }) + } + fn type_decl_impl<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -1684,6 +1736,12 @@ impl Parser { ty: ctx.new_scalar(Scalar::F16), ty_span: Span::UNDEFINED, }, + "coop_mat8x8" => self.cooperative_matrix_with_type( + lexer, + ctx, + crate::CooperativeSize::Eight, + crate::CooperativeSize::Eight, + )?, "atomic" => { let scalar = lexer.next_scalar_generic()?; ast::Type::Atomic(scalar) diff --git a/naga/src/ir/mod.rs 
b/naga/src/ir/mod.rs index e70e1f650d..2338e3bba2 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -443,7 +443,7 @@ impl From for u32 { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum CooperativeVectorSize { +pub enum CooperativeSize { Eight = 8, } @@ -474,7 +474,7 @@ pub enum ScalarKind { AbstractFloat, } -/// Primitive type for a scalar. +/// Primitive type for a cooperative scalar. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -492,6 +492,18 @@ impl CooperativeScalar { } } +/// Role of a cooperative variable in the equation "A * B + C" +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeRole { + A, + B, + C, +} + /// Characteristics of a scalar type. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -743,9 +755,10 @@ pub enum TypeInner { /// Matrix that is cooperatively processed by all the threads /// in an opaque mapping. CooperativeMatrix { - columns: CooperativeVectorSize, - rows: CooperativeVectorSize, + columns: CooperativeSize, + rows: CooperativeSize, scalar: CooperativeScalar, + role: CooperativeRole, }, /// Atomic scalar. 
Atomic(Scalar), diff --git a/naga/src/proc/layouter.rs b/naga/src/proc/layouter.rs index 5e7aed8a0f..7f9380d766 100644 --- a/naga/src/proc/layouter.rs +++ b/naga/src/proc/layouter.rs @@ -86,8 +86,8 @@ impl From for Alignment { } } -impl From for Alignment { - fn from(size: crate::CooperativeVectorSize) -> Self { +impl From for Alignment { + fn from(size: crate::CooperativeSize) -> Self { Self(unsafe { NonZeroU32::new_unchecked(size as u32) }) } } @@ -222,6 +222,7 @@ impl Layouter { columns: _, rows, scalar, + role: _, } => { let alignment = Alignment::new(scalar.width() as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index 24a14868f9..54fec19a89 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -206,6 +206,7 @@ impl crate::TypeInner { columns, rows, scalar, + role: _, } => Some(columns as u32 * rows as u32 * scalar.width() as u32), Self::Pointer { .. } | Self::ValuePointer { .. 
} => Some(POINTER_SPAN), Self::Array { diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 155991b0e8..7a8524bd2e 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -419,6 +419,7 @@ impl super::Validator { columns: _, rows: _, scalar, + role: _, } => { self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; if scalar != crate::CooperativeScalar::F32 { diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml new file mode 100644 index 0000000000..1bfa633ff0 --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -0,0 +1,2 @@ +targets = "SPIRV" +god_mode = true diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl new file mode 100644 index 0000000000..335034818f --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -0,0 +1,7 @@ +var a: coop_mat8x8; +var b: coop_mat8x8; + +@compute @workgroup_size(8, 8, 1) +fn main() { + //let c = a * b; +} diff --git a/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron b/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron new file mode 100644 index 0000000000..f806c3f3dd --- /dev/null +++ b/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron @@ -0,0 +1,78 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 0, + space: Private, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: 
(""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(0), + ), + ( + uniformity: ( + non_uniform_result: Some(2), + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 0, + space: Private, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(2), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(0), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(0), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + ), + ], + const_expression_types: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron new file mode 100644 index 0000000000..1298f69e2c --- /dev/null +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -0,0 +1,84 @@ +( + types: [ + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: F32, + role: A, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("a"), + space: Private, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (8, 8, 1), + workgroup_size_overrides: None, + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(0), + Load( + pointer: 0, + ), + GlobalVariable(0), + Load( + pointer: 2, + ), + Binary( + op: Add, + left: 1, + right: 3, + ), + ], + named_expressions: { + 4: "a2", + }, + body: [ + Emit(( + start: 1, + end: 2, + )), + Emit(( + start: 3, + 
end: 5, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron new file mode 100644 index 0000000000..1298f69e2c --- /dev/null +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -0,0 +1,84 @@ +( + types: [ + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: F32, + role: A, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("a"), + space: Private, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (8, 8, 1), + workgroup_size_overrides: None, + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(0), + Load( + pointer: 0, + ), + GlobalVariable(0), + Load( + pointer: 2, + ), + Binary( + op: Add, + left: 1, + right: 3, + ), + ], + named_expressions: { + 4: "a2", + }, + body: [ + Emit(( + start: 1, + end: 2, + )), + Emit(( + start: 3, + end: 5, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm new file mode 100644 index 0000000000..33e7477e5d --- /dev/null +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -0,0 +1,17 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 7 
+OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %4 "main" +OpExecutionMode %4 LocalSize 8 8 1 +%2 = OpTypeVoid +%5 = OpTypeFunction %2 +%4 = OpFunction %2 None %5 +%3 = OpLabel +OpBranch %6 +%6 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file From 8b4a22b0683b870e038df293ef0d59e71a19c125 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sat, 20 Sep 2025 00:20:32 -0700 Subject: [PATCH 05/12] coop: handle simple ops, end-to-end with SPIRV --- naga/src/back/spv/block.rs | 16 +++++++- naga/src/back/spv/instructions.rs | 15 ++++---- naga/src/back/spv/mod.rs | 1 + naga/src/back/spv/writer.rs | 18 ++++++++- naga/src/ir/mod.rs | 9 +++++ naga/src/proc/type_methods.rs | 1 + naga/src/proc/typifier.rs | 25 +++++++++++++ naga/src/valid/expression.rs | 43 ++++++++++++++++++++-- naga/tests/in/wgsl/cooperative-matrix.toml | 4 ++ naga/tests/in/wgsl/cooperative-matrix.wgsl | 4 +- 10 files changed, 120 insertions(+), 16 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 7758d86c41..c80f9035d9 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -19,6 +19,7 @@ fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { crate::TypeInner::Scalar(_) => Dimension::Scalar, crate::TypeInner::Vector { .. } => Dimension::Vector, crate::TypeInner::Matrix { .. } => Dimension::Matrix, + crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix, _ => unreachable!(), } } @@ -766,6 +767,7 @@ impl BlockContext<'_> { rows, scalar, } => { + //TODO: why not just rely on `Fadd` for matrices? self.write_matrix_matrix_column_op( block, id, @@ -781,6 +783,7 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } + crate::TypeInner::CooperativeMatrix { .. 
} => spirv::Op::FAdd, _ => unimplemented!(), }, crate::BinaryOperator::Subtract => match *left_ty_inner { @@ -809,6 +812,7 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } + crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub, _ => unimplemented!(), }, crate::BinaryOperator::Multiply => { @@ -842,10 +846,12 @@ impl BlockContext<'_> { (Dimension::Vector, Dimension::Matrix) => { spirv::Op::VectorTimesMatrix } - (Dimension::Matrix, Dimension::Scalar) => { + (Dimension::Matrix, Dimension::Scalar) + | (Dimension::CooperativeMatrix, Dimension::Scalar) => { spirv::Op::MatrixTimesScalar } - (Dimension::Scalar, Dimension::Matrix) => { + (Dimension::Scalar, Dimension::Matrix) + | (Dimension::Scalar, Dimension::CooperativeMatrix) => { reverse_operands = true; spirv::Op::MatrixTimesScalar } @@ -864,6 +870,12 @@ impl BlockContext<'_> { } (Dimension::Vector, Dimension::Vector) | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, + (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix) + //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication + | (Dimension::CooperativeMatrix, _) + | (_, Dimension::CooperativeMatrix) => { + unimplemented!() + } } } crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() { diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index bb559606d9..9e542917f3 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -284,17 +284,18 @@ impl super::Instruction { pub(super) fn type_coop_matrix( id: Word, scalar_type_id: Word, - row_count: crate::CooperativeSize, - column_count: crate::CooperativeSize, - role: spirv::CooperativeMatrixUse, + scope_id: Word, + row_count_id: Word, + column_count_id: Word, + matrix_use_id: Word, ) -> Self { let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR); instruction.set_result(id); instruction.add_operand(scalar_type_id); - 
instruction.add_operand(spirv::Scope::Subgroup as u32); - instruction.add_operand(column_count as u32); - instruction.add_operand(row_count as u32); - instruction.add_operand(role as u32); + instruction.add_operand(scope_id); + instruction.add_operand(row_count_id); + instruction.add_operand(column_count_id); + instruction.add_operand(matrix_use_id); instruction } diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index c283b035a7..f2622c466a 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -482,6 +482,7 @@ enum Dimension { Scalar, Vector, Matrix, + CooperativeMatrix, } /// Key used to look up an operation which we have wrapped in a helper diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 1b31194fec..12392ac29e 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1368,7 +1368,9 @@ impl Writer { "cooperative matrix", &[spirv::Capability::CooperativeMatrixKHR], )?; + self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?; self.use_extension("SPV_KHR_cooperative_matrix"); + self.use_extension("SPV_KHR_vulkan_memory_model"); } _ => {} } @@ -1405,7 +1407,12 @@ impl Writer { role, } => { let scalar_id = self.get_cooperative_type_id(scalar); - Instruction::type_coop_matrix(id, scalar_id, rows, columns, role.into()) + let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let columns_id = self.get_index_constant(columns as u32); + let rows_id = self.get_index_constant(rows as u32); + let role_id = + self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32); + Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id) } }; @@ -2669,7 +2676,14 @@ impl Writer { } let addressing_model = spirv::AddressingModel::Logical; - let memory_model = spirv::MemoryModel::GLSL450; + let memory_model = if self + .capabilities_used + .contains(&spirv::Capability::VulkanMemoryModel) + { + spirv::MemoryModel::Vulkan + } else { + 
spirv::MemoryModel::GLSL450 + }; //self.check(addressing_model.required_capabilities())?; //self.check(memory_model.required_capabilities())?; diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 2338e3bba2..7962a0582e 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -490,6 +490,15 @@ impl CooperativeScalar { Self::F32 => 4, } } + + pub const fn to_scalar(&self) -> Scalar { + match *self { + Self::F32 => Scalar { + kind: ScalarKind::Float, + width: 4, + }, + } + } } /// Role of a cooperative variable in the equation "A * B + C" diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index 54fec19a89..c4a9091c74 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -115,6 +115,7 @@ impl crate::TypeInner { match *self { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar), Ti::Matrix { scalar, .. } => Some(scalar), + Ti::CooperativeMatrix { scalar, .. } => Some(scalar.to_scalar()), _ => None, } } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 79b4f95e10..89599e079c 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -143,6 +143,17 @@ impl Clone for TypeResolution { columns, scalar, }, + Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + } => Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }, Ti::Pointer { base, space } => Ti::Pointer { base, space }, Ti::ValuePointer { size, @@ -587,6 +598,20 @@ impl<'a> ResolveContext<'a> { (&Ti::Scalar { .. }, _) => res_right.clone(), (_, &Ti::Scalar { .. }) => res_left.clone(), (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(), + ( + &Ti::CooperativeMatrix { + columns: _, + rows, + scalar, + role, + }, + &Ti::CooperativeMatrix { columns, .. 
}, + ) => TypeResolution::Value(Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }), (tl, tr) => { return Err(ResolveError::IncompatibleOperands(format!( "{tl:?} * {tr:?}" diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 68023b5bf0..8bb9af142b 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -788,7 +788,9 @@ impl super::Validator { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, - Ti::Matrix { .. } => left_inner == right_inner, + Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => { + left_inner == right_inner + } _ => false, }, Bo::Divide | Bo::Modulo => match *left_inner { @@ -818,7 +820,7 @@ impl super::Validator { scalar: scalar2, .. }, ) => scalar1 == scalar2, - // Scalar/matrix. + // Scalar * matrix. ( &Ti::Scalar(Sc { kind: Sk::Float, .. @@ -831,7 +833,7 @@ impl super::Validator { kind: Sk::Float, .. }), ) => true, - // Vector/vector. + // Vector * vector. ( &Ti::Vector { size: size1, @@ -864,9 +866,44 @@ impl super::Validator { }, &Ti::Matrix { rows, .. }, ) => size == rows, + // Matrix * matrix. (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { columns == rows } + // Coop matrix * coop matrix. + ( + &Ti::CooperativeMatrix { + columns, + scalar: scalar1, + role: role1, + .. + }, + &Ti::CooperativeMatrix { + rows, + scalar: scalar2, + role: role2, + .. + }, + ) => columns == rows && scalar1 == scalar2 && role1 == role2, + // Scalar * coop matrix. + ( + &Ti::Scalar(Sc { + kind: Sk::Float, .. + }), + &Ti::CooperativeMatrix { + scalar: crate::CooperativeScalar::F32, + .. + }, + ) + | ( + &Ti::CooperativeMatrix { + scalar: crate::CooperativeScalar::F32, + .. + }, + &Ti::Scalar(Sc { + kind: Sk::Float, .. 
+ }), + ) => true, _ => false, }; let left_width = left_inner.scalar_width().unwrap_or(0); diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml index 1bfa633ff0..c06d67d21a 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.toml +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -1,2 +1,6 @@ targets = "SPIRV" god_mode = true + +[spv] +debug = true +version = [1, 4] diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 335034818f..91e371b9fb 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -1,7 +1,7 @@ var a: coop_mat8x8; -var b: coop_mat8x8; +//var b: coop_mat8x8; @compute @workgroup_size(8, 8, 1) fn main() { - //let c = a * b; + let a2 = a + a; } From d1c9b568cee8f3d6c6cf42f8341d24a94ce7558c Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 21 Sep 2025 23:47:35 -0700 Subject: [PATCH 06/12] coop: mulAdd instruction --- naga/src/back/dot/mod.rs | 6 ++++ naga/src/back/glsl/mod.rs | 3 +- naga/src/back/hlsl/writer.rs | 4 ++- naga/src/back/msl/writer.rs | 7 +++++ naga/src/back/pipeline_constants.rs | 9 ++++++ naga/src/back/spv/block.rs | 11 +++++++ naga/src/back/spv/instructions.rs | 12 ++++++++ naga/src/back/wgsl/writer.rs | 9 ++++++ naga/src/compact/expressions.rs | 12 ++++++++ naga/src/front/wgsl/lower/mod.rs | 11 +++++-- naga/src/ir/mod.rs | 9 ++++++ naga/src/proc/constant_evaluator.rs | 3 ++ naga/src/proc/typifier.rs | 1 + naga/src/valid/analyzer.rs | 5 ++++ naga/src/valid/expression.rs | 35 ++++++++++++++++++++++ naga/src/valid/function.rs | 5 ++-- naga/src/valid/handles.rs | 3 ++ naga/tests/in/wgsl/cooperative-matrix.wgsl | 5 ++-- 18 files changed, 141 insertions(+), 9 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 826dad1c21..d23641f7a0 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -742,6 +742,12 @@ fn write_function_expressions( let 
ty = if committed { "Committed" } else { "Candidate" }; (format!("get{ty}HitVertexPositions").into(), 4) } + E::MulAdd { a, b, c } => { + edges.insert("a", a); + edges.insert("b", b); + edges.insert("c", c); + ("MulAdd".into(), 6) + } }; // give uniform expressions an outline diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 515ffd7b1e..03483cd3d3 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -4341,7 +4341,8 @@ impl<'a, W: Write> Writer<'a, W> { } // not supported yet Expression::RayQueryGetIntersection { .. } - | Expression::RayQueryVertexPositions { .. } => unreachable!(), + | Expression::RayQueryVertexPositions { .. } + | Expression::MulAdd { .. } => unreachable!(), } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index ab95b9327f..0c09e7fbdb 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -4275,7 +4275,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } // Not supported yet - Expression::RayQueryVertexPositions { .. } => unreachable!(), + Expression::RayQueryVertexPositions { .. } | Expression::MulAdd { .. } => { + unreachable!() + } // Nothing to do here, since call expression already cached Expression::CallResult(_) | Expression::AtomicResult { .. 
} diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 100be84f5c..180021a6a7 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -2842,6 +2842,13 @@ impl Writer { } write!(self.out, "}}")?; } + crate::Expression::MulAdd { a, b, c } => { + self.put_expression(a, context, false)?; + write!(self.out, " * ")?; + self.put_expression(b, context, false)?; + write!(self.out, " + ")?; + self.put_expression(c, context, false)?; + } } Ok(()) } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d2b3ed70ed..852683d52a 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -633,6 +633,15 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E } => { adjust(query); } + Expression::MulAdd { + ref mut a, + ref mut b, + ref mut c, + } => { + adjust(a); + adjust(b); + adjust(c); + } } } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index c80f9035d9..3b5c7ddd1e 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1805,6 +1805,17 @@ impl BlockContext<'_> { )?; self.write_ray_query_return_vertex_position(query, block, committed) } + crate::Expression::MulAdd { a, b, c } => { + let id = self.gen_id(); + block.body.push(Instruction::coop_mul_add( + result_type_id, + id, + self.cached[a], + self.cached[b], + self.cached[c], + )); + id + } }; self.cached[expr_handle] = id; diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 9e542917f3..3091b6cfee 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1245,6 +1245,18 @@ impl super::Instruction { instruction } + + // Cooperative operations + pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + 
instruction.add_operand(a); + instruction.add_operand(b); + instruction.add_operand(c); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 225a63343b..a1e9b31847 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1685,6 +1685,15 @@ impl Writer { write!(self.out, ")")? } + Expression::MulAdd { a, b, c } => { + write!(self.out, "mulAdd(")?; + self.write_expr(module, a, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, b, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, c, func_ctx)?; + write!(self.out, ")")? + } // Not supported yet Expression::RayQueryGetIntersection { .. } | Expression::RayQueryVertexPositions { .. } => unreachable!(), diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index f36d747a93..98f3bbc3c9 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -253,6 +253,9 @@ impl ExpressionTracer<'_> { } => { self.expressions_used.insert(query); } + Ex::MulAdd { a, b, c } => { + self.expressions_used.insert_iter([a, b, c]); + } } } } @@ -419,6 +422,15 @@ impl ModuleMap { ref mut query, committed: _, } => adjust(query), + Ex::MulAdd { + ref mut a, + ref mut b, + ref mut c, + } => { + adjust(a); + adjust(b); + adjust(c); + } } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index b599223561..511150e2ee 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3082,7 +3082,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "quadSwapY" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3106,7 +3105,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "quadSwapDiagonal" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3130,6 +3128,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return 
Ok(Some(result)); } + "coopMulAdd" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let a = self.expression(args.next()?, ctx)?; + let b = self.expression(args.next()?, ctx)?; + let c = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::MulAdd { a, b, c } + } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 7962a0582e..6679ef88be 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -1790,6 +1790,15 @@ pub enum Expression { /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation /// [`SubgroupGather`]: Statement::SubgroupGather SubgroupOperationResult { ty: Handle }, + + /// Return a * b + c. + /// Currently only supported for [`TypeInner::CooperativeMatrix`] types, + /// where it's only valid in uniform control flow. + MulAdd { + a: Handle, + b: Handle, + c: Handle, + }, } /// The value of the switch case. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index f5a5d25ca8..d571183ab3 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -584,6 +584,8 @@ pub enum ConstantEvaluatorError { "Expected reject and accept args. to be scalars of vectors of the same type, got something else", )] SelectAcceptRejectTypeMismatch, + #[error("Cooperative operations can't be constant")] + CooperativeOperation, } impl<'a> ConstantEvaluator<'a> { @@ -971,6 +973,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::SubgroupOperationResult { .. } => { Err(ConstantEvaluatorError::SubgroupExpression) } + Expression::MulAdd { .. 
} => Err(ConstantEvaluatorError::CooperativeOperation), } } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 89599e079c..f90f36de68 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -801,6 +801,7 @@ impl<'a> ResolveContext<'a> { scalar: crate::Scalar::U32, size: crate::VectorSize::Quad, }), + crate::Expression::MulAdd { a, b: _, c: _ } => past(a)?.clone(), }) } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 95ae40dcdb..523ed030f7 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -29,6 +29,7 @@ bitflags::bitflags! { const WORK_GROUP_BARRIER = 0x1; const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 }; const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 }; + const COOP_OPS = 0x8; } } @@ -822,6 +823,10 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::MulAdd { a, b, c } => Uniformity { + non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))), + requirements: UniformityRequirements::COOP_OPS, + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 8bb9af142b..01541ca65e 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -141,6 +141,8 @@ pub enum ExpressionError { Literal(#[from] LiteralError), #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), + #[error("Invalid operand for MulAdd")] + InvalidMulAddOperand, } #[derive(Clone, Debug, thiserror::Error)] @@ -1267,6 +1269,39 @@ impl super::Validator { } }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. 
} => self.subgroup_stages, + E::MulAdd { a, b, c } => { + match resolver[a] { + Ti::CooperativeMatrix { + role: crate::CooperativeRole::A, + .. + } => {} + ref other => { + log::error!("A operand type: {other:?}"); + return Err(ExpressionError::InvalidMulAddOperand); + } + } + match resolver[b] { + Ti::CooperativeMatrix { + role: crate::CooperativeRole::B, + .. + } => {} + ref other => { + log::error!("B operand type: {other:?}"); + return Err(ExpressionError::InvalidMulAddOperand); + } + } + match resolver[c] { + Ti::CooperativeMatrix { + role: crate::CooperativeRole::C, + .. + } => {} + ref other => { + log::error!("C operand type: {other:?}"); + return Err(ExpressionError::InvalidMulAddOperand); + } + } + ShaderStages::COMPUTE + } }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index dc19e19176..60a4ec8815 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -796,9 +796,8 @@ impl super::Validator { | Ex::As { .. } | Ex::ArrayLength(_) | Ex::RayQueryGetIntersection { .. } - | Ex::RayQueryVertexPositions { .. } => { - self.emit_expression(handle, context)? - } + | Ex::RayQueryVertexPositions { .. } + | Ex::MulAdd { .. } => self.emit_expression(handle, context)?, Ex::CallResult(_) | Ex::AtomicResult { .. } | Ex::WorkGroupUniformLoadResult { .. 
} diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 2cfb32ebe1..cc58212312 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -648,6 +648,9 @@ impl super::Validator { } => { handle.check_dep(query)?; } + crate::Expression::MulAdd { a, b, c } => { + handle.check_dep(a)?.check_dep(b)?.check_dep(c)?; + } } Ok(()) } diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 91e371b9fb..2380046fbc 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -1,7 +1,8 @@ var a: coop_mat8x8; -//var b: coop_mat8x8; +var b: coop_mat8x8; +var c: coop_mat8x8; @compute @workgroup_size(8, 8, 1) fn main() { - let a2 = a + a; + let d = coopMulAdd(a, b, c); } From e604132a9bfa3279e134fed56ada883527e99230 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 23 Sep 2025 00:22:46 -0700 Subject: [PATCH 07/12] coop: Implement Load/Store statement --- naga/src/back/dot/mod.rs | 22 ++- naga/src/back/glsl/mod.rs | 3 +- naga/src/back/hlsl/writer.rs | 4 +- naga/src/back/mod.rs | 12 -- naga/src/back/msl/writer.rs | 131 ++++++++++++++++-- naga/src/back/pipeline_constants.rs | 15 +- naga/src/back/spv/block.rs | 51 ++++++- naga/src/back/spv/instructions.rs | 34 +++++ naga/src/back/spv/mod.rs | 2 +- naga/src/back/spv/writer.rs | 9 +- naga/src/back/wgsl/writer.rs | 37 +++-- naga/src/common/wgsl/to_wgsl.rs | 19 --- naga/src/compact/expressions.rs | 8 +- naga/src/compact/statements.rs | 26 ++++ naga/src/front/spv/mod.rs | 1 + naga/src/front/wgsl/error.rs | 6 +- naga/src/front/wgsl/lower/construction.rs | 7 +- naga/src/front/wgsl/lower/mod.rs | 51 +++++-- naga/src/front/wgsl/parse/mod.rs | 10 +- naga/src/ir/mod.rs | 43 ++---- naga/src/proc/constant_evaluator.rs | 4 +- naga/src/proc/layouter.rs | 2 +- naga/src/proc/terminator.rs | 3 +- naga/src/proc/type_methods.rs | 19 ++- naga/src/proc/typifier.rs | 18 ++- naga/src/valid/analyzer.rs | 20 ++- 
naga/src/valid/expression.rs | 34 ++--- naga/src/valid/function.rs | 70 +++++++++- naga/src/valid/handles.rs | 14 +- naga/src/valid/type.rs | 4 +- naga/tests/in/wgsl/cooperative-matrix.toml | 2 +- naga/tests/in/wgsl/cooperative-matrix.wgsl | 8 +- .../ir/wgsl-cooperative-matrix.compact.ron | 129 ++++++++++++++--- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 129 ++++++++++++++--- .../out/spv/wgsl-cooperative-matrix.spvasm | 85 ++++++++++-- 35 files changed, 798 insertions(+), 234 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index d23641f7a0..3dafc274ca 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -403,6 +403,24 @@ impl StatementGraph { }, } } + S::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major: _, + } => { + self.dependencies.push((id, target, "target")); + self.dependencies.push((id, pointer, "pointer")); + if let Some(stride) = stride { + self.dependencies.push((id, stride, "stride")); + } + if store { + "Store" + } else { + "Load" + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -742,11 +760,11 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("get{ty}HitVertexPositions").into(), 4) } - E::MulAdd { a, b, c } => { + E::CooperativeMultiplyAdd { a, b, c } => { edges.insert("a", a); edges.insert("b", b); edges.insert("c", c); - ("MulAdd".into(), 6) + ("cooperativeMultiplyAdd".into(), 4) } }; diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 03483cd3d3..8a015fccd9 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2805,6 +2805,7 @@ impl<'a, W: Write> Writer<'a, W> { } writeln!(self.out, ");")?; } + Statement::CooperativeLoadStore { .. } => unimplemented!(), } Ok(()) @@ -4342,7 +4343,7 @@ impl<'a, W: Write> Writer<'a, W> { // not supported yet Expression::RayQueryGetIntersection { .. } | Expression::RayQueryVertexPositions { .. } - | Expression::MulAdd { .. 
} => unreachable!(), + | Expression::CooperativeMultiplyAdd { .. } => unreachable!(), } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 0c09e7fbdb..edf76e2c20 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2747,6 +2747,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } writeln!(self.out, ");")?; } + Statement::CooperativeLoadStore { .. } => unimplemented!(), } Ok(()) @@ -4275,7 +4276,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } // Not supported yet - Expression::RayQueryVertexPositions { .. } | Expression::MulAdd { .. } => { + Expression::RayQueryVertexPositions { .. } + | Expression::CooperativeMultiplyAdd { .. } => { unreachable!() } // Nothing to do here, since call expression already cached diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 0d13d63dd9..0fe8e9274f 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -311,18 +311,6 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { } } -impl crate::TypeInner { - /// Returns true if this is a handle to a type rather than the type directly. - pub const fn is_handle(&self) -> bool { - match *self { - crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } - | crate::TypeInner::AccelerationStructure { .. } => true, - _ => false, - } - } -} - impl crate::Statement { /// Returns true if the statement directly terminates the current block. /// diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 180021a6a7..4e5020cf4d 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp /// allowing them to be conveniently passed to user-defined or wrapper /// functions. The struct is declared in [`Writer::write_type_defs`]. 
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper"; +pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd"; /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. /// @@ -483,6 +484,12 @@ enum WrappedFunction { ImageQuerySize { class: crate::ImageClass, }, + CooperativeMultiplyAdd { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + intermediate: crate::CooperativeSize, + scalar: crate::Scalar, + }, } pub struct Writer { @@ -543,14 +550,6 @@ impl crate::Scalar { } } -impl crate::CooperativeScalar { - const fn to_msl_name(self) -> &'static str { - match self { - Self::F32 => "float", - } - } -} - const fn separate(need_separator: bool) -> &'static str { if need_separator { "," @@ -2842,12 +2841,14 @@ impl Writer { } write!(self.out, "}}")?; } - crate::Expression::MulAdd { a, b, c } => { - self.put_expression(a, context, false)?; - write!(self.out, " * ")?; - self.put_expression(b, context, false)?; - write!(self.out, " + ")?; - self.put_expression(c, context, false)?; + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?; + self.put_expression(a, context, true)?; + write!(self.out, ", ")?; + self.put_expression(b, context, true)?; + write!(self.out, ", ")?; + self.put_expression(c, context, true)?; + write!(self.out, ")")?; } } Ok(()) @@ -4230,6 +4231,49 @@ impl Writer { } writeln!(self.out, ");")?; } + crate::Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + } => { + let op_str = if store { "store" } else { "load" }; + write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?; + self.put_expression(target, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_expression(pointer, &context.expression, true)?; + if stride.is_some() || row_major { + write!(self.out, ", ")?; + match stride { + Some(expression) => { + 
self.put_expression(expression, &context.expression, true)?; + } + None => { + let default_stride = match *context.expression.resolve_type(target) + { + crate::TypeInner::CooperativeMatrix { + columns, rows, .. + } => { + if row_major { + columns as u32 + } else { + rows as u32 + } + } + _ => 0, + }; + write!(self.out, "{default_stride}")?; + } + } + } + if row_major { + let matrix_origin = "0"; + let transpose = true; + write!(self.out, ", {matrix_origin}, {transpose}")?; + } + writeln!(self.out, ");")?; + } } } @@ -6286,6 +6330,62 @@ template Ok(()) } + fn write_wrapped_cooperative_multiply_add( + &mut self, + module: &crate::Module, + func_ctx: &back::FunctionCtx, + a: Handle, + b: Handle, + ) -> BackendResult { + let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) { + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + .. + } => (columns, rows, scalar), + _ => unreachable!(), + }; + let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) { + crate::TypeInner::CooperativeMatrix { columns, rows, .. 
} => (columns, rows), + _ => unreachable!(), + }; + let wrapped = WrappedFunction::CooperativeMultiplyAdd { + columns: b_c, + rows: a_r, + intermediate: a_c, + scalar, + }; + if !self.wrapped_functions.insert(wrapped) { + return Ok(()); + } + let scalar_name = match scalar.width { + 2 => "half", + 4 => "float", + 8 => "double", + _ => unreachable!(), + }; + writeln!( + self.out, + "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{", + b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32, + )?; + let l1 = back::Level(1); + writeln!( + self.out, + "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;", + b_c as u32, a_r as u32 + )?; + writeln!( + self.out, + "{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);" + )?; + writeln!(self.out, "{l1}return d;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } + pub(super) fn write_wrapped_functions( &mut self, module: &crate::Module, @@ -6360,6 +6460,9 @@ template crate::Expression::ImageQuery { image, query } => { self.write_wrapped_image_query(module, func_ctx, image, query)?; } + crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => { + self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?; + } _ => {} } } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 852683d52a..e4dd2ba551 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -633,7 +633,7 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E } => { adjust(query); } - Expression::MulAdd { + Expression::CooperativeMultiplyAdd { ref mut a, ref mut b, ref mut c, @@ -844,6 +844,19 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } + Statement::CooperativeLoadStore { + store: _, + 
ref mut target, + ref mut pointer, + ref mut stride, + row_major: _, + } => { + adjust(target); + adjust(pointer); + if let Some(ref mut stride) = *stride { + adjust(stride); + } + } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 3b5c7ddd1e..d9af863e5a 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1805,14 +1805,21 @@ impl BlockContext<'_> { )?; self.write_ray_query_return_vertex_position(query, block, committed) } - crate::Expression::MulAdd { a, b, c } => { + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + self.writer.require_any( + "CooperativeMatrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + let a_id = self.cached[a]; + let b_id = self.cached[b]; + let c_id = self.cached[c]; let id = self.gen_id(); block.body.push(Instruction::coop_mul_add( result_type_id, id, - self.cached[a], - self.cached[b], - self.cached[c], + a_id, + b_id, + c_id, )); id } @@ -3677,6 +3684,42 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } + Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + } => { + let layout = if row_major { + spirv::CooperativeMatrixLayout::RowMajorKHR + } else { + spirv::CooperativeMatrixLayout::ColumnMajorKHR + }; + let layout_id = self.get_index_constant(layout as u32); + let stride_id = stride.map(|exp| self.cached[exp]); + if store { + block.body.push(Instruction::coop_store( + self.cached[target], + self.cached[pointer], + layout_id, + stride_id, + )); + } else { + let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty); + let id = self.gen_id(); + block.body.push(Instruction::coop_load( + result_type_id, + id, + self.cached[pointer], + layout_id, + stride_id, + )); + block + .body + .push(Instruction::store(self.cached[target], id, None)); + } + } } } diff --git a/naga/src/back/spv/instructions.rs 
b/naga/src/back/spv/instructions.rs index 3091b6cfee..419c276fc4 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1247,6 +1247,40 @@ impl super::Instruction { } // Cooperative operations + pub(super) fn coop_load( + result_type_id: Word, + id: Word, + pointer_id: Word, + layout_id: Word, + stride_id: Option, + ) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_id); + instruction.add_operand(layout_id); + if let Some(stride_id) = stride_id { + instruction.add_operand(stride_id); + } + + instruction + } + pub(super) fn coop_store( + id: Word, + pointer_id: Word, + layout_id: Word, + stride_id: Option, + ) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR); + instruction.add_operand(pointer_id); + instruction.add_operand(id); + instruction.add_operand(layout_id); + if let Some(stride_id) = stride_id { + instruction.add_operand(stride_id); + } + + instruction + } pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR); instruction.set_type(result_type_id); diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index f2622c466a..90f90f30ee 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -346,7 +346,7 @@ enum CooperativeType { Matrix { columns: crate::CooperativeSize, rows: crate::CooperativeSize, - scalar: crate::CooperativeScalar, + scalar: crate::Scalar, role: crate::CooperativeRole, }, } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 12392ac29e..ad08818639 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -376,12 +376,6 @@ impl Writer { }) } - pub(super) fn get_cooperative_type_id(&mut self, scalar: crate::CooperativeScalar) -> Word { - match scalar { - crate::CooperativeScalar::F32 
=> self.get_f32_type_id(), - } - } - pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { let f32_id = self.get_f32_type_id(); self.get_pointer_type_id(f32_id, class) @@ -1406,7 +1400,8 @@ impl Writer { scalar, role, } => { - let scalar_id = self.get_cooperative_type_id(scalar); + let scalar_id = + self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar))); let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let columns_id = self.get_index_constant(columns as u32); let rows_id = self.get_index_constant(rows as u32); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index a1e9b31847..d05a461f7c 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -984,6 +984,25 @@ impl Writer { } writeln!(self.out, ");")?; } + Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + } => { + let op_str = if store { "Store" } else { "Load" }; + let suffix = if row_major { "T" } else { "" }; + write!(self.out, "coop{op_str}{suffix}(")?; + self.write_expr(module, target, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, pointer, func_ctx)?; + if let Some(stride) = stride { + write!(self.out, ", ")?; + self.write_expr(module, stride, func_ctx)?; + } + write!(self.out, ")")? + } } Ok(()) @@ -1685,15 +1704,6 @@ impl Writer { write!(self.out, ")")? } - Expression::MulAdd { a, b, c } => { - write!(self.out, "mulAdd(")?; - self.write_expr(module, a, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, b, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, c, func_ctx)?; - write!(self.out, ")")? - } // Not supported yet Expression::RayQueryGetIntersection { .. } | Expression::RayQueryVertexPositions { .. } => unreachable!(), @@ -1704,6 +1714,15 @@ impl Writer { | Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. 
} => {} + Expression::CooperativeMultiplyAdd { a, b, c } => { + write!(self.out, "coopMultiplyAdd(")?; + self.write_expr(module, a, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, b, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, c, func_ctx)?; + write!(self.out, ")")?; + } } Ok(()) diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 1a3f5e5e17..1cdf3eb5cf 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -299,25 +299,6 @@ impl TryToWgsl for crate::Scalar { } } -impl TryToWgsl for crate::CooperativeScalar { - const DESCRIPTION: &'static str = "cooperative scalar type"; - - fn try_to_wgsl(self) -> Option<&'static str> { - use crate::CooperativeScalar; - - Some(match self { - CooperativeScalar::F32 => "f32", - }) - } - - fn to_wgsl_for_diagnostics(self) -> String { - match self.try_to_wgsl() { - Some(static_string) => static_string.to_string(), - None => unreachable!(), - } - } -} - impl ToWgsl for crate::CooperativeRole { fn to_wgsl(self) -> &'static str { match self { diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 98f3bbc3c9..2b2117cc16 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -253,8 +253,10 @@ impl ExpressionTracer<'_> { } => { self.expressions_used.insert(query); } - Ex::MulAdd { a, b, c } => { - self.expressions_used.insert_iter([a, b, c]); + Ex::CooperativeMultiplyAdd { a, b, c } => { + self.expressions_used.insert(a); + self.expressions_used.insert(b); + self.expressions_used.insert(c); } } } @@ -422,7 +424,7 @@ impl ModuleMap { ref mut query, committed: _, } => adjust(query), - Ex::MulAdd { + Ex::CooperativeMultiplyAdd { ref mut a, ref mut b, ref mut c, diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 39d6065f5f..4124fe907e 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -152,6 +152,19 @@ impl 
FunctionTracer<'_> { self.expressions_used.insert(argument); self.expressions_used.insert(result); } + St::CooperativeLoadStore { + store: _, + target, + pointer, + stride, + row_major: _, + } => { + self.expressions_used.insert(target); + self.expressions_used.insert(pointer); + if let Some(stride) = stride { + self.expressions_used.insert(stride); + } + } // Trivial statements. St::Break @@ -371,6 +384,19 @@ impl FunctionMap { adjust(argument); adjust(result); } + St::CooperativeLoadStore { + store: _, + ref mut target, + ref mut pointer, + ref mut stride, + row_major: _, + } => { + adjust(target); + adjust(pointer); + if let Some(ref mut stride) = *stride { + adjust(stride); + } + } // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 803e52553d..e5ae2ec2e8 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4654,6 +4654,7 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), + S::CooperativeLoadStore { .. 
} => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 8c749acc73..f0d6a4b848 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -413,7 +413,7 @@ pub(crate) enum Error<'a> { span: Span, }, UnderspecifiedCooperativeMatrix, - UnknownCooperativeScalar(Span), + UnsupportedCooperativeScalar(Span), } impl From for Error<'_> { @@ -1393,8 +1393,8 @@ impl<'a> Error<'a> { labels: vec![], notes: vec![format!("must be F32")], }, - Error::UnknownCooperativeScalar(span) => ParseError { - message: "unknown cooperative scalar type".into(), + Error::UnsupportedCooperativeScalar(span) => ParseError { + message: "cooperative scalar type is not supported".into(), labels: vec![(span, "type needs the scalar type specified".into())], notes: vec![format!("must be F32")], }, diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 9ac11bfc98..2159ef01ad 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -650,11 +650,8 @@ impl<'source> Lowerer<'source, '_> { } => { let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?; let scalar = match ctx.module.types[ty].inner { - crate::TypeInner::Scalar(crate::Scalar { - kind: crate::ScalarKind::Float, - width: 4, - }) => crate::CooperativeScalar::F32, - _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + crate::TypeInner::Scalar(s) => s, + _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))), }; let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix { columns, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 511150e2ee..0240179b64 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1677,8 +1677,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .as_expression(block, &mut emitter) .interrupt_emitter(ir::Expression::LocalVariable(var), 
Span::UNDEFINED)?; block.extend(emitter.finish(&ctx.function.expressions)); - ctx.local_table - .insert(v.handle, Declared::Runtime(Typed::Reference(handle))); + let typed = if ctx.module.types[ty].inner.is_handle() { + Typed::Plain(handle) + } else { + Typed::Reference(handle) + }; + ctx.local_table.insert(v.handle, Declared::Runtime(typed)); match initializer { Some(initializer) => ir::Statement::Store { @@ -2134,8 +2138,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expr = match *global { LoweredGlobalDecl::Var(handle) => { let expr = ir::Expression::GlobalVariable(handle); - match ctx.module.global_variables[handle].space { + let v = &ctx.module.global_variables[handle]; + let force_value = ctx.module.types[v.ty].inner.is_handle(); + match v.space { ir::AddressSpace::Handle => Typed::Plain(expr), + _ if force_value => Typed::Plain(expr), _ => Typed::Reference(expr), } } @@ -3128,14 +3135,41 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "coopMulAdd" => { + "coopLoad" | "coopLoadT" | "coopStore" | "coopStoreT" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let target = self.expression(args.next()?, ctx)?; + let pointer = self.expression(args.next()?, ctx)?; + let stride = if args.total_args > 2 { + Some(self.expression(args.next()?, ctx)?) 
+ } else { + None + }; + args.finish()?; + + let store = function.name.contains("Store"); + let row_major = function.name.ends_with("T"); + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + }, + span, + ); + return Ok(None); + } + "coopMultiplyAdd" => { let mut args = ctx.prepare_args(arguments, 3, span); let a = self.expression(args.next()?, ctx)?; let b = self.expression(args.next()?, ctx)?; let c = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::MulAdd { a, b, c } + ir::Expression::CooperativeMultiplyAdd { a, b, c } } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) @@ -3971,11 +4005,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } => { let ty = self.resolve_ast_type(ty, ctx)?; let scalar = match ctx.module.types[ty].inner { - ir::TypeInner::Scalar(crate::Scalar { - kind: crate::ScalarKind::Float, - width: 4, - }) => crate::CooperativeScalar::F32, - _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + ir::TypeInner::Scalar(s) => s, + _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))), }; ir::TypeInner::CooperativeMatrix { columns, diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 49d7eaab25..576bd9c977 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -658,12 +658,10 @@ impl Parser { ty_span: Span::UNDEFINED, })) } - "coop_mat8x8" => { - return Ok(Some(ast::ConstructorType::PartialCooperativeMatrix { - columns: crate::CooperativeSize::Eight, - rows: crate::CooperativeSize::Eight, - })) - } + "coop_mat8x8" => ast::ConstructorType::PartialCooperativeMatrix { + columns: crate::CooperativeSize::Eight, + rows: crate::CooperativeSize::Eight, + }, "array" => ast::ConstructorType::PartialArray, "atomic" | "binding_array" diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 
6679ef88be..55d381b911 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -474,33 +474,6 @@ pub enum ScalarKind { AbstractFloat, } -/// Primitive type for a cooperative scalar. -#[repr(u8)] -#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum CooperativeScalar { - F32, -} - -impl CooperativeScalar { - pub const fn width(&self) -> Bytes { - match *self { - Self::F32 => 4, - } - } - - pub const fn to_scalar(&self) -> Scalar { - match *self { - Self::F32 => Scalar { - kind: ScalarKind::Float, - width: 4, - }, - } - } -} - /// Role of a cooperative variable in the equation "A * B + C" #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -766,7 +739,7 @@ pub enum TypeInner { CooperativeMatrix { columns: CooperativeSize, rows: CooperativeSize, - scalar: CooperativeScalar, + scalar: Scalar, role: CooperativeRole, }, /// Atomic scalar. @@ -1791,10 +1764,8 @@ pub enum Expression { /// [`SubgroupGather`]: Statement::SubgroupGather SubgroupOperationResult { ty: Handle }, - /// Return a * b + c. - /// Currently only supported for [`TypeInner::CooperativeMatrix`] types, - /// where it's only valid in uniform control flow. - MulAdd { + /// Compute `a * b + c` + CooperativeMultiplyAdd { a: Handle, b: Handle, c: Handle, @@ -2240,6 +2211,14 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, + /// Load from or store into a cooperative primitive. + CooperativeLoadStore { + store: bool, + target: Handle, + pointer: Handle, + stride: Option>, + row_major: bool, + }, } /// A function argument. 
diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index d571183ab3..b0508193bf 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -973,7 +973,9 @@ impl<'a> ConstantEvaluator<'a> { Expression::SubgroupOperationResult { .. } => { Err(ConstantEvaluatorError::SubgroupExpression) } - Expression::MulAdd { .. } => Err(ConstantEvaluatorError::CooperativeOperation), + Expression::CooperativeMultiplyAdd { .. } => { + Err(ConstantEvaluatorError::CooperativeOperation) + } } } diff --git a/naga/src/proc/layouter.rs b/naga/src/proc/layouter.rs index 7f9380d766..5165ac7a01 100644 --- a/naga/src/proc/layouter.rs +++ b/naga/src/proc/layouter.rs @@ -224,7 +224,7 @@ impl Layouter { scalar, role: _, } => { - let alignment = Alignment::new(scalar.width() as u32) + let alignment = Alignment::new(scalar.width as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; TypeLayout { size, diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index b29ccb054a..a670694f23 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -43,7 +43,8 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } | S::ControlBarrier(_) - | S::MemoryBarrier(_)), + | S::MemoryBarrier(_) + | S::CooperativeLoadStore { .. }), ) | None => block.push(S::Return { value: None }, Default::default()), } diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index c4a9091c74..136ea29218 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -115,7 +115,7 @@ impl crate::TypeInner { match *self { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar), Ti::Matrix { scalar, .. } => Some(scalar), - Ti::CooperativeMatrix { scalar, .. } => Some(scalar.to_scalar()), + Ti::CooperativeMatrix { scalar, .. 
} => Some(scalar), _ => None, } } @@ -183,14 +183,25 @@ impl crate::TypeInner { pub fn is_atomic_pointer(&self, types: &crate::UniqueArena) -> bool { match *self { - crate::TypeInner::Pointer { base, .. } => match types[base].inner { - crate::TypeInner::Atomic { .. } => true, + Self::Pointer { base, .. } => match types[base].inner { + Self::Atomic { .. } => true, _ => false, }, _ => false, } } + /// Returns true if a variable of this type is a handle. + pub const fn is_handle(&self) -> bool { + match *self { + Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure { .. } + | Self::CooperativeMatrix { .. } => true, + _ => false, + } + } + /// Attempt to calculate the size of this type. Returns `None` if the size /// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`]. pub fn try_size(&self, gctx: super::GlobalCtx) -> Option { @@ -208,7 +219,7 @@ impl crate::TypeInner { rows, scalar, role: _, - } => Some(columns as u32 * rows as u32 * scalar.width() as u32), + } => Some(columns as u32 * rows as u32 * scalar.width as u32), Self::Pointer { .. } | Self::ValuePointer { .. 
} => Some(POINTER_SPAN), Self::Array { base: _, diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index f90f36de68..8e323d7724 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -454,7 +454,8 @@ impl<'a> ResolveContext<'a> { } crate::Expression::GlobalVariable(h) => { let var = &self.global_vars[h]; - if var.space == crate::AddressSpace::Handle { + let ty = &types[var.ty].inner; + if var.space == crate::AddressSpace::Handle || ty.is_handle() { TypeResolution::Handle(var.ty) } else { TypeResolution::Value(Ti::Pointer { @@ -465,10 +466,15 @@ impl<'a> ResolveContext<'a> { } crate::Expression::LocalVariable(h) => { let var = &self.local_vars[h]; - TypeResolution::Value(Ti::Pointer { - base: var.ty, - space: crate::AddressSpace::Function, - }) + let ty = &types[var.ty].inner; + if ty.is_handle() { + TypeResolution::Handle(var.ty) + } else { + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: crate::AddressSpace::Function, + }) + } } crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) { Ti::Pointer { base, space: _ } => { @@ -801,7 +807,7 @@ impl<'a> ResolveContext<'a> { scalar: crate::Scalar::U32, size: crate::VectorSize::Quad, }), - crate::Expression::MulAdd { a, b: _, c: _ } => past(a)?.clone(), + crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(), }) } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 523ed030f7..9daf2314b5 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -823,7 +823,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, - E::MulAdd { a, b, c } => Uniformity { + E::CooperativeMultiplyAdd { a, b, c } => Uniformity { non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))), requirements: UniformityRequirements::COOP_OPS, }, @@ -1156,6 +1156,24 @@ impl FunctionInfo { } FunctionUniformity::new() } + 
S::CooperativeLoadStore { + store: _, + target, + pointer, + stride, + row_major: _, + } => { + if let Some(stride) = stride { + let _ = self.add_ref(stride); + } + FunctionUniformity { + result: Uniformity { + non_uniform_result: self.add_ref(target).or(self.add_ref(pointer)), + requirements: UniformityRequirements::COOP_OPS, + }, + exit: ExitFlags::empty(), + } + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 01541ca65e..466ca26b60 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -141,8 +141,8 @@ pub enum ExpressionError { Literal(#[from] LiteralError), #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), - #[error("Invalid operand for MulAdd")] - InvalidMulAddOperand, + #[error("Invalid operand for cooperative op")] + InvalidCooperativeOperand(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -888,24 +888,10 @@ impl super::Validator { }, ) => columns == rows && scalar1 == scalar2 && role1 == role2, // Scalar * coop matrix. - ( - &Ti::Scalar(Sc { - kind: Sk::Float, .. - }), - &Ti::CooperativeMatrix { - scalar: crate::CooperativeScalar::F32, - .. - }, - ) - | ( - &Ti::CooperativeMatrix { - scalar: crate::CooperativeScalar::F32, - .. - }, - &Ti::Scalar(Sc { - kind: Sk::Float, .. - }), - ) => true, + (&Ti::Scalar(s1), &Ti::CooperativeMatrix { scalar: s2, .. }) + | (&Ti::CooperativeMatrix { scalar: s1, .. }, &Ti::Scalar(s2)) => { + s1 == s2 + } _ => false, }; let left_width = left_inner.scalar_width().unwrap_or(0); @@ -1269,7 +1255,7 @@ impl super::Validator { } }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. 
} => self.subgroup_stages, - E::MulAdd { a, b, c } => { + E::CooperativeMultiplyAdd { a, b, c } => { match resolver[a] { Ti::CooperativeMatrix { role: crate::CooperativeRole::A, @@ -1277,7 +1263,7 @@ impl super::Validator { } => {} ref other => { log::error!("A operand type: {other:?}"); - return Err(ExpressionError::InvalidMulAddOperand); + return Err(ExpressionError::InvalidCooperativeOperand(a)); } } match resolver[b] { @@ -1287,7 +1273,7 @@ impl super::Validator { } => {} ref other => { log::error!("B operand type: {other:?}"); - return Err(ExpressionError::InvalidMulAddOperand); + return Err(ExpressionError::InvalidCooperativeOperand(b)); } } match resolver[c] { @@ -1297,7 +1283,7 @@ impl super::Validator { } => {} ref other => { log::error!("C operand type: {other:?}"); - return Err(ExpressionError::InvalidMulAddOperand); + return Err(ExpressionError::InvalidCooperativeOperand(c)); } } ShaderStages::COMPUTE diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 60a4ec8815..136c4b17f1 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1,6 +1,5 @@ use alloc::{format, string::String}; -use super::validate_atomic_compare_exchange_struct; use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, ExpressionError, FunctionInfo, ModuleInfo, @@ -213,6 +212,10 @@ pub enum FunctionError { WorkgroupUniformLoadInvalidPointer(Handle), #[error("Subgroup operation is invalid")] InvalidSubgroup(#[from] SubgroupError), + #[error("Invalid target type for a cooperative store")] + InvalidCooperativeStoreTarget(Handle), + #[error("Cooperative load/store data pointer has invalid type")] + InvalidCooperativeDataPointer(Handle), #[error("Emit statement should not cover \"result\" expressions like {0:?}")] EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] @@ -576,7 +579,7 @@ impl super::Validator { .with_span_handle(result, context.expressions) .into_other()); }; - if 
!validate_atomic_compare_exchange_struct( + if !super::validate_atomic_compare_exchange_struct( context.types, members, |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(pointer_scalar), @@ -797,7 +800,9 @@ impl super::Validator { | Ex::ArrayLength(_) | Ex::RayQueryGetIntersection { .. } | Ex::RayQueryVertexPositions { .. } - | Ex::MulAdd { .. } => self.emit_expression(handle, context)?, + | Ex::CooperativeMultiplyAdd { .. } => { + self.emit_expression(handle, context)? + } Ex::CallResult(_) | Ex::AtomicResult { .. } | Ex::WorkGroupUniformLoadResult { .. } @@ -1073,7 +1078,7 @@ impl super::Validator { } else if let Some(tr) = pointer_base_tr { context.compare_types(value_tr, &tr) } else { - false + value_ty.is_handle() }; if !good { @@ -1617,6 +1622,63 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } + S::CooperativeLoadStore { + store, + target, + pointer, + stride: _, + row_major: _, + } => { + stages &= super::ShaderStages::COMPUTE; + + let target_scalar = + match *context.resolve_type_inner(target, &self.valid_expression_set)? { + Ti::CooperativeMatrix { scalar, .. 
} => scalar, + ref other => { + log::error!("Target operand type: {other:?}"); + return Err(FunctionError::InvalidCooperativeStoreTarget(target) + .with_span_handle(target, context.expressions)); + } + }; + + let ty_inner = + context.resolve_type_inner(pointer, &self.valid_expression_set)?; + //TODO: validate stride + let (pty_array, space) = match *ty_inner { + crate::TypeInner::Pointer { base, space } => (base, space), + _ => { + return Err(FunctionError::InvalidCooperativeDataPointer(pointer) + .with_span_handle(pointer, context.expressions)) + } + }; + let pty_scalar = match context.types[pty_array].inner { + crate::TypeInner::Array { + base, + size: _, + stride: _, + } => base, + _ => { + return Err(FunctionError::InvalidCooperativeDataPointer(pointer) + .with_span_handle(pointer, context.expressions)) + } + }; + let space = match context.types[pty_scalar].inner { + crate::TypeInner::Scalar(s) if s == target_scalar => space, + _ => { + return Err(FunctionError::InvalidCooperativeDataPointer(pointer) + .with_span_handle(pointer, context.expressions)) + } + }; + + if store && !space.access().contains(crate::StorageAccess::STORE) { + return Err( + FunctionError::InvalidStorePointer(pointer).with_span_static( + context.expressions.get_span(pointer), + "writing to this location is not permitted", + ), + ); + } + } } } Ok(BlockInfo { stages }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index cc58212312..15a778778c 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -648,7 +648,7 @@ impl super::Validator { } => { handle.check_dep(query)?; } - crate::Expression::MulAdd { a, b, c } => { + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { handle.check_dep(a)?.check_dep(b)?.check_dep(c)?; } } @@ -839,6 +839,18 @@ impl super::Validator { validate_expr(result)?; Ok(()) } + crate::Statement::CooperativeLoadStore { + store: _, + target, + pointer, + stride, + row_major: _, + } => { + validate_expr(target)?; + 
validate_expr(pointer)?; + validate_expr_opt(stride)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 7a8524bd2e..93cdae34e1 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -422,7 +422,7 @@ impl super::Validator { role: _, } => { self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; - if scalar != crate::CooperativeScalar::F32 { + if scalar.kind != crate::ScalarKind::Float || scalar.width != 4 { return Err(TypeError::MatrixElementNotFloat); } TypeInfo::new( @@ -433,7 +433,7 @@ impl super::Validator { | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED, - Alignment::from_width(scalar.width()), + Alignment::from_width(scalar.width), ) } Ti::Atomic(scalar) => { diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml index c06d67d21a..4a3be8b94e 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.toml +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -1,4 +1,4 @@ -targets = "SPIRV" +targets = "IR | SPIRV | METAL" god_mode = true [spv] diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 2380046fbc..24ecb9a2b3 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -1,8 +1,12 @@ var a: coop_mat8x8; var b: coop_mat8x8; -var c: coop_mat8x8; +@group(0) @binding(0) +var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { - let d = coopMulAdd(a, b, c); + var c = coop_mat8x8(); + coopLoad(c, &ext); + var d = coopMultiplyAdd(a, b, c); + coopStore(c, &ext); } diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 1298f69e2c..31d47d603a 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -1,14 +1,56 @@ ( 
types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), ( name: None, inner: CooperativeMatrix( columns: Eight, rows: Eight, - scalar: F32, + scalar: ( + kind: Float, + width: 4, + ), role: A, ), ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: B, + ), + ), + ( + name: None, + inner: Array( + base: 0, + size: Dynamic, + stride: 4, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: C, + ), + ), ], special_types: ( ray_desc: None, @@ -25,7 +67,26 @@ name: Some("a"), space: Private, binding: None, - ty: 0, + ty: 1, + init: None, + ), + ( + name: Some("b"), + space: Private, + binding: None, + ty: 2, + init: None, + ), + ( + name: Some("ext"), + space: Storage( + access: ("LOAD | STORE"), + ), + binding: Some(( + group: 0, + binding: 0, + )), + ty: 3, init: None, ), ], @@ -42,34 +103,56 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [ - GlobalVariable(0), - Load( - pointer: 0, + local_variables: [ + ( + name: Some("c"), + ty: 4, + init: Some(0), ), - GlobalVariable(0), - Load( - pointer: 2, + ( + name: Some("d"), + ty: 4, + init: None, ), - Binary( - op: Add, - left: 1, - right: 3, + ], + expressions: [ + ZeroValue(4), + LocalVariable(0), + GlobalVariable(2), + GlobalVariable(0), + GlobalVariable(1), + CooperativeMultiplyAdd( + a: 3, + b: 4, + c: 1, ), + LocalVariable(1), + GlobalVariable(2), ], - named_expressions: { - 4: "a2", - }, + named_expressions: {}, body: [ + CooperativeLoadStore( + store: false, + target: 1, + pointer: 2, + stride: None, + row_major: false, + ), Emit(( - start: 1, - end: 2, - )), - Emit(( - start: 3, - end: 5, + start: 5, + end: 6, )), + Store( + pointer: 6, + value: 5, + ), + CooperativeLoadStore( + store: true, + target: 1, + pointer: 7, + stride: None, + row_major: false, + ), Return( value: None, 
), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 1298f69e2c..31d47d603a 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -1,14 +1,56 @@ ( types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), ( name: None, inner: CooperativeMatrix( columns: Eight, rows: Eight, - scalar: F32, + scalar: ( + kind: Float, + width: 4, + ), role: A, ), ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: B, + ), + ), + ( + name: None, + inner: Array( + base: 0, + size: Dynamic, + stride: 4, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: C, + ), + ), ], special_types: ( ray_desc: None, @@ -25,7 +67,26 @@ name: Some("a"), space: Private, binding: None, - ty: 0, + ty: 1, + init: None, + ), + ( + name: Some("b"), + space: Private, + binding: None, + ty: 2, + init: None, + ), + ( + name: Some("ext"), + space: Storage( + access: ("LOAD | STORE"), + ), + binding: Some(( + group: 0, + binding: 0, + )), + ty: 3, init: None, ), ], @@ -42,34 +103,56 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [ - GlobalVariable(0), - Load( - pointer: 0, + local_variables: [ + ( + name: Some("c"), + ty: 4, + init: Some(0), ), - GlobalVariable(0), - Load( - pointer: 2, + ( + name: Some("d"), + ty: 4, + init: None, ), - Binary( - op: Add, - left: 1, - right: 3, + ], + expressions: [ + ZeroValue(4), + LocalVariable(0), + GlobalVariable(2), + GlobalVariable(0), + GlobalVariable(1), + CooperativeMultiplyAdd( + a: 3, + b: 4, + c: 1, ), + LocalVariable(1), + GlobalVariable(2), ], - named_expressions: { - 4: "a2", - }, + named_expressions: {}, body: [ + CooperativeLoadStore( + store: false, + target: 1, + pointer: 2, + stride: None, + row_major: 
false, + ), Emit(( - start: 1, - end: 2, - )), - Emit(( - start: 3, - end: 5, + start: 5, + end: 6, )), + Store( + pointer: 6, + value: 5, + ), + CooperativeLoadStore( + store: true, + target: 1, + pointer: 7, + stride: None, + row_major: false, + ), Return( value: None, ), diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index 33e7477e5d..0e8a882994 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -1,17 +1,82 @@ ; SPIR-V -; Version: 1.1 +; Version: 1.4 ; Generator: rspirv -; Bound: 7 +; Bound: 37 OpCapability Shader +OpCapability CooperativeMatrixKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" %1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %4 "main" -OpExecutionMode %4 LocalSize 8 8 1 +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %25 "main" %15 %18 %21 +OpExecutionMode %25 LocalSize 8 8 1 +%3 = OpString "cooperative-matrix.wgsl" +OpSource Unknown 0 %3 "var a: coop_mat8x8; +var b: coop_mat8x8; +@group(0) @binding(0) +var ext: array; + +@compute @workgroup_size(8, 8, 1) +fn main() { + var c = coop_mat8x8(); + coopLoad(c, &ext); + var d = coopMultiplyAdd(a, b, c); + coopStore(c, &ext); +} +" +OpName %15 "a" +OpName %18 "b" +OpName %21 "ext" +OpName %25 "main" +OpName %30 "c" +OpName %32 "d" +OpDecorate %12 ArrayStride 4 +OpDecorate %21 DescriptorSet 0 +OpDecorate %21 Binding 0 +OpDecorate %22 Block +OpMemberDecorate %22 0 Offset 0 %2 = OpTypeVoid -%5 = OpTypeFunction %2 -%4 = OpFunction %2 None %5 -%3 = OpLabel -OpBranch %6 -%6 = OpLabel +%4 = OpTypeFloat 32 +%7 = OpTypeInt 32 0 +%6 = OpConstant %7 3 +%8 = OpConstant %7 8 +%9 = OpConstant %7 0 +%5 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %9 +%11 = OpConstant %7 1 +%10 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %11 +%12 = OpTypeRuntimeArray %4 +%14 = OpConstant 
%7 2 +%13 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %14 +%16 = OpTypePointer Private %5 +%17 = OpConstantNull %5 +%15 = OpVariable %16 Private %17 +%19 = OpTypePointer Private %10 +%20 = OpConstantNull %10 +%18 = OpVariable %19 Private %20 +%22 = OpTypeStruct %12 +%23 = OpTypePointer StorageBuffer %22 +%21 = OpVariable %23 StorageBuffer +%26 = OpTypeFunction %2 +%27 = OpTypePointer StorageBuffer %12 +%29 = OpConstantNull %13 +%31 = OpTypePointer Function %13 +%33 = OpConstantNull %13 +%25 = OpFunction %2 None %26 +%24 = OpLabel +%30 = OpVariable %31 Function %29 +%32 = OpVariable %31 Function %33 +%28 = OpAccessChain %27 %21 %9 +OpBranch %34 +%34 = OpLabel +OpLine %3 9 5 +%35 = OpCooperativeMatrixLoadKHR %13 %28 %11 +OpStore %30 %35 +OpLine %3 10 13 +%36 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 +OpLine %3 10 5 +OpStore %32 %36 +OpLine %3 11 5 +OpCooperativeMatrixStoreKHR %28 %30 %11 OpReturn OpFunctionEnd \ No newline at end of file From 7f365363e36b4a739f471f7a011d4c0701f5421b Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 25 Sep 2025 21:32:39 -0700 Subject: [PATCH 08/12] coop: fixes and changelog --- CHANGELOG.md | 6 ++-- naga/src/back/spv/block.rs | 16 ++++++++-- naga/src/valid/function.rs | 20 +++--------- naga/tests/in/wgsl/cooperative-matrix.wgsl | 4 +-- .../ir/wgsl-cooperative-matrix.compact.ron | 32 ++++++++++++++----- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 32 ++++++++++++++----- .../tests/out/msl/wgsl-cooperative-matrix.msl | 31 ++++++++++++++++++ .../out/spv/wgsl-cooperative-matrix.spvasm | 22 ++++++++----- 8 files changed, 117 insertions(+), 46 deletions(-) create mode 100644 naga/tests/out/msl/wgsl-cooperative-matrix.msl diff --git a/CHANGELOG.md b/CHANGELOG.md index fb156281c0..69759b8e2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,7 +44,7 @@ Bottom level categories: #### Deferred command buffer actions: `map_buffer_on_submit` and `on_submitted_work_done` -You may schedule buffer mapping and a submission-complete 
callback to run automatically after you submit, directly from encoders, command buffers, and passes. +You may schedule buffer mapping and a submission-complete callback to run automatically after you submit, directly from encoders, command buffers, and passes. ```rust // Record some GPU work so the submission isn't empty and touches `buffer`. @@ -150,7 +150,7 @@ By @cwfitzgerald in [#8163](https://github.com/gfx-rs/wgpu/pull/8163). #### Multi-draw indirect is now unconditionally supported when indirect draws are supported -We have removed `Features::MULTI_DRAW_INDIRECT` as it was unconditionally available on all platforms. +We have removed `Features::MULTI_DRAW_INDIRECT` as it was unconditionally available on all platforms. `RenderPass::multi_draw_indirect` is now available if the device supports downlevel flag `DownlevelFlags::INDIRECT_EXECUTION`. If you are using spirv-passthrough with multi-draw indirect and `gl_DrawID`, you can know if `MULTI_DRAW_INDIRECT` is being emulated @@ -166,6 +166,8 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162). - Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386). +- Added support for cooperative matrix load/store operations in shaders. Currently, only WGSL is supported as input, and only SPIR-V and Metal as output. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
+ ### Changes #### General diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index d9af863e5a..ac66a897cc 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3691,6 +3691,18 @@ impl BlockContext<'_> { stride, row_major, } => { + let pointer_id = match self.write_access_chain( + pointer, + &mut block, + AccessTypeAdjustment::None, + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Cooperative load/store out-of-bounds handling", + )); + } + }; let layout = if row_major { spirv::CooperativeMatrixLayout::RowMajorKHR } else { @@ -3701,7 +3713,7 @@ impl BlockContext<'_> { if store { block.body.push(Instruction::coop_store( self.cached[target], - self.cached[pointer], + pointer_id, layout_id, stride_id, )); @@ -3711,7 +3723,7 @@ impl BlockContext<'_> { block.body.push(Instruction::coop_load( result_type_id, id, - self.cached[pointer], + pointer_id, layout_id, stride_id, )); diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 136c4b17f1..e1263e88fb 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1641,32 +1641,20 @@ impl super::Validator { } }; - let ty_inner = - context.resolve_type_inner(pointer, &self.valid_expression_set)?; + let ty_inner = context.resolve_pointer_type(pointer); //TODO: validate stride - let (pty_array, space) = match *ty_inner { + let (pty_scalar, space) = match *ty_inner { crate::TypeInner::Pointer { base, space } => (base, space), _ => { return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)) - } - }; - let pty_scalar = match context.types[pty_array].inner { - crate::TypeInner::Array { - base, - size: _, - stride: _, - } => base, - _ => { - return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)) + .with_span_handle(pointer,
context.expressions)); } }; let space = match context.types[pty_scalar].inner { crate::TypeInner::Scalar(s) if s == target_scalar => space, _ => { return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)) + .with_span_handle(pointer, context.expressions)); } }; diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 24ecb9a2b3..e65fe0d589 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -6,7 +6,7 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { var c = coop_mat8x8(); - coopLoad(c, &ext); + coopLoad(c, &ext[4]); var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext); + coopStore(c, &ext[0]); } diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 31d47d603a..7582580360 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -119,40 +119,56 @@ ZeroValue(4), LocalVariable(0), GlobalVariable(2), + AccessIndex( + base: 2, + index: 4, + ), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 3, - b: 4, + a: 4, + b: 5, c: 1, ), LocalVariable(1), GlobalVariable(2), + AccessIndex( + base: 8, + index: 0, + ), ], named_expressions: {}, body: [ CooperativeLoadStore( store: false, target: 1, - pointer: 2, + pointer: 3, stride: None, row_major: false, ), Emit(( - start: 5, - end: 6, + start: 3, + end: 4, + )), + Emit(( + start: 6, + end: 7, )), Store( - pointer: 6, - value: 5, + pointer: 7, + value: 6, ), CooperativeLoadStore( store: true, target: 1, - pointer: 7, + pointer: 9, stride: None, row_major: false, ), + Emit(( + start: 9, + end: 10, + )), Return( value: None, ), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 31d47d603a..7582580360 100644 --- 
a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -119,40 +119,56 @@ ZeroValue(4), LocalVariable(0), GlobalVariable(2), + AccessIndex( + base: 2, + index: 4, + ), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 3, - b: 4, + a: 4, + b: 5, c: 1, ), LocalVariable(1), GlobalVariable(2), + AccessIndex( + base: 8, + index: 0, + ), ], named_expressions: {}, body: [ CooperativeLoadStore( store: false, target: 1, - pointer: 2, + pointer: 3, stride: None, row_major: false, ), Emit(( - start: 5, - end: 6, + start: 3, + end: 4, + )), + Emit(( + start: 6, + end: 7, )), Store( - pointer: 6, - value: 5, + pointer: 7, + value: 6, ), CooperativeLoadStore( store: true, target: 1, - pointer: 7, + pointer: 9, stride: None, row_major: false, ), + Emit(( + start: 9, + end: 10, + )), Return( value: None, ), diff --git a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl new file mode 100644 index 0000000000..bed4406760 --- /dev/null +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -0,0 +1,31 @@ +// language: metal1.0 +#include <metal_stdlib> +#include <simd/simd.h> + +using metal::uint; + +struct _mslBufferSizes { + uint size2; +}; + +typedef float type_3[1]; +metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const metal::simdgroup_float8x8& a, const metal::simdgroup_float8x8& b, const metal::simdgroup_float8x8& c) { + metal::simdgroup_float8x8 d; + metal::simdgroup_multiply_accumulate(d,a,b,c); + return d; +} + + +kernel void main_( + device type_3 const& ext [[user(fake0)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] +) { + metal::simdgroup_float8x8 a = {}; + metal::simdgroup_float8x8 b = {}; + metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {}; + metal::simdgroup_float8x8 d = {}; + metal::simdgroup_load(c, ext[4]); + d = NagaCooperativeMultiplyAdd(a, b, c); + metal::simdgroup_store(c, ext[0]); + return; +} diff --git
a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index 0e8a882994..a3626919a9 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 37 +; Bound: 41 OpCapability Shader OpCapability CooperativeMatrixKHR OpCapability VulkanMemoryModel @@ -20,9 +20,9 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { var c = coop_mat8x8(); - coopLoad(c, &ext); + coopLoad(c, &ext[4]); var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext); + coopStore(c, &ext[0]); } " OpName %15 "a" @@ -62,6 +62,8 @@ OpMemberDecorate %22 0 Offset 0 %29 = OpConstantNull %13 %31 = OpTypePointer Function %13 %33 = OpConstantNull %13 +%35 = OpTypePointer StorageBuffer %4 +%36 = OpConstant %7 4 %25 = OpFunction %2 None %26 %24 = OpLabel %30 = OpVariable %31 Function %29 @@ -70,13 +72,17 @@ OpMemberDecorate %22 0 Offset 0 OpBranch %34 %34 = OpLabel OpLine %3 9 5 -%35 = OpCooperativeMatrixLoadKHR %13 %28 %11 -OpStore %30 %35 +%37 = OpAccessChain %35 %28 %36 +%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 +OpStore %30 %38 +OpLine %3 9 18 OpLine %3 10 13 -%36 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 +%39 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 OpLine %3 10 5 -OpStore %32 %36 +OpStore %32 %39 OpLine %3 11 5 -OpCooperativeMatrixStoreKHR %28 %30 %11 +%40 = OpAccessChain %35 %28 %9 +OpCooperativeMatrixStoreKHR %40 %30 %11 +OpLine %3 11 19 OpReturn OpFunctionEnd \ No newline at end of file From 321ddf435200535fbf37e266557d716c86df59b3 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 25 Sep 2025 22:36:48 -0700 Subject: [PATCH 09/12] coop: make stride non-optional --- naga/src/back/dot/mod.rs | 4 +- naga/src/back/msl/writer.rs | 26 +------------ naga/src/back/pipeline_constants.rs | 4 +- naga/src/back/spv/block.rs | 5 +-- naga/src/back/spv/instructions.rs | 19 ++-------- 
naga/src/back/wgsl/writer.rs | 6 +-- naga/src/compact/statements.rs | 8 +--- naga/src/front/wgsl/lower/mod.rs | 24 +++++++++--- naga/src/ir/mod.rs | 2 +- naga/src/valid/analyzer.rs | 22 +++++------ naga/src/valid/handles.rs | 2 +- naga/tests/in/wgsl/cooperative-matrix.toml | 2 +- .../ir/wgsl-cooperative-matrix.compact.ron | 38 ++++++++++--------- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 38 ++++++++++--------- .../tests/out/msl/wgsl-cooperative-matrix.msl | 4 +- .../out/spv/wgsl-cooperative-matrix.spvasm | 8 ++-- .../out/wgsl/wgsl-cooperative-matrix.wgsl | 13 +++++++ 17 files changed, 105 insertions(+), 120 deletions(-) create mode 100644 naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 3dafc274ca..358c4c35c4 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -412,9 +412,7 @@ impl StatementGraph { } => { self.dependencies.push((id, target, "target")); self.dependencies.push((id, pointer, "pointer")); - if let Some(stride) = stride { - self.dependencies.push((id, stride, "stride")); - } + self.dependencies.push((id, stride, "stride")); if store { "Store" } else { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 4e5020cf4d..9e520c4c13 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -4243,30 +4243,8 @@ impl Writer { self.put_expression(target, &context.expression, true)?; write!(self.out, ", ")?; self.put_expression(pointer, &context.expression, true)?; - if stride.is_some() || row_major { - write!(self.out, ", ")?; - match stride { - Some(expression) => { - self.put_expression(expression, &context.expression, true)?; - } - None => { - let default_stride = match *context.expression.resolve_type(target) - { - crate::TypeInner::CooperativeMatrix { - columns, rows, .. 
- } => { - if row_major { - columns as u32 - } else { - rows as u32 - } - } - _ => 0, - }; - write!(self.out, "{default_stride}")?; - } - } - } + write!(self.out, ", ")?; + self.put_expression(stride, &context.expression, true)?; if row_major { let matrix_origin = "0"; let transpose = true; diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index e4dd2ba551..5a9fd9558d 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -853,9 +853,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } => { adjust(target); adjust(pointer); - if let Some(ref mut stride) = *stride { - adjust(stride); - } + adjust(stride); } Statement::Break | Statement::Continue diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index ac66a897cc..0920a86f39 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3709,13 +3709,12 @@ impl BlockContext<'_> { spirv::CooperativeMatrixLayout::ColumnMajorKHR }; let layout_id = self.get_index_constant(layout as u32); - let stride_id = stride.map(|exp| self.cached[exp]); if store { block.body.push(Instruction::coop_store( self.cached[target], pointer_id, layout_id, - stride_id, + self.cached[stride], )); } else { let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty); @@ -3725,7 +3724,7 @@ impl BlockContext<'_> { id, pointer_id, layout_id, - stride_id, + self.cached[stride], )); block .body diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 419c276fc4..22eaa99340 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1252,33 +1252,22 @@ impl super::Instruction { id: Word, pointer_id: Word, layout_id: Word, - stride_id: Option, + stride_id: Word, ) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(pointer_id); 
instruction.add_operand(layout_id); - if let Some(stride_id) = stride_id { - instruction.add_operand(stride_id); - } - + instruction.add_operand(stride_id); instruction } - pub(super) fn coop_store( - id: Word, - pointer_id: Word, - layout_id: Word, - stride_id: Option, - ) -> Self { + pub(super) fn coop_store(id: Word, pointer_id: Word, layout_id: Word, stride_id: Word) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR); instruction.add_operand(pointer_id); instruction.add_operand(id); instruction.add_operand(layout_id); - if let Some(stride_id) = stride_id { - instruction.add_operand(stride_id); - } - + instruction.add_operand(stride_id); instruction } pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index d05a461f7c..95fc3af6c0 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -997,10 +997,8 @@ impl Writer { self.write_expr(module, target, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, pointer, func_ctx)?; - if let Some(stride) = stride { - write!(self.out, ", ")?; - self.write_expr(module, stride, func_ctx)?; - } + write!(self.out, ", ")?; + self.write_expr(module, stride, func_ctx)?; write!(self.out, ")")? } } diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 4124fe907e..5c36a40274 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -161,9 +161,7 @@ impl FunctionTracer<'_> { } => { self.expressions_used.insert(target); self.expressions_used.insert(pointer); - if let Some(stride) = stride { - self.expressions_used.insert(stride); - } + self.expressions_used.insert(stride); } // Trivial statements. @@ -393,9 +391,7 @@ impl FunctionMap { } => { adjust(target); adjust(pointer); - if let Some(ref mut stride) = *stride { - adjust(stride); - } + adjust(stride); } // Trivial statements. 
diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 0240179b64..7b2c6866cf 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3136,19 +3136,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some(result)); } "coopLoad" | "coopLoadT" | "coopStore" | "coopStoreT" => { + let store = function.name.contains("Store"); + let row_major = function.name.ends_with("T"); + let mut args = ctx.prepare_args(arguments, 2, span); let target = self.expression(args.next()?, ctx)?; let pointer = self.expression(args.next()?, ctx)?; let stride = if args.total_args > 2 { - Some(self.expression(args.next()?, ctx)?) + self.expression(args.next()?, ctx)? } else { - None + // Infer the stride from the matrix type + let stride = match *resolve_inner!(ctx, target) { + ir::TypeInner::CooperativeMatrix { columns, rows, .. } => { + if row_major { + columns as u32 + } else { + rows as u32 + } + } + _ => 0, + }; + ctx.append_expression( + ir::Expression::Literal(ir::Literal::U32(stride)), + Span::UNDEFINED, + )? 
}; args.finish()?; - let store = function.name.contains("Store"); - let row_major = function.name.ends_with("T"); - let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::CooperativeLoadStore { diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 55d381b911..854c7d5719 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -2216,7 +2216,7 @@ pub enum Statement { store: bool, target: Handle, pointer: Handle, - stride: Option>, + stride: Handle, row_major: bool, }, } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 9daf2314b5..fdd98d64f2 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1162,18 +1162,16 @@ impl FunctionInfo { pointer, stride, row_major: _, - } => { - if let Some(stride) = stride { - let _ = self.add_ref(stride); - } - FunctionUniformity { - result: Uniformity { - non_uniform_result: self.add_ref(target).or(self.add_ref(pointer)), - requirements: UniformityRequirements::COOP_OPS, - }, - exit: ExitFlags::empty(), - } - } + } => FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .add_ref(target) + .or(self.add_ref(pointer)) + .or(self.add_ref(stride)), + requirements: UniformityRequirements::COOP_OPS, + }, + exit: ExitFlags::empty(), + }, }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 15a778778c..1bd33eaf3c 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -848,7 +848,7 @@ impl super::Validator { } => { validate_expr(target)?; validate_expr(pointer)?; - validate_expr_opt(stride)?; + validate_expr(stride)?; Ok(()) } crate::Statement::Break diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml index 4a3be8b94e..a95da7bf80 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.toml +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -1,4 +1,4 @@ -targets = "IR | SPIRV | METAL" +targets = "IR | 
SPIRV | METAL | WGSL" god_mode = true [spv] diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 7582580360..7f8fc73568 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -123,52 +123,54 @@ base: 2, index: 4, ), + Literal(U32(8)), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 4, - b: 5, + a: 5, + b: 6, c: 1, ), LocalVariable(1), GlobalVariable(2), AccessIndex( - base: 8, + base: 9, index: 0, ), + Literal(U32(8)), ], named_expressions: {}, body: [ + Emit(( + start: 3, + end: 4, + )), CooperativeLoadStore( store: false, target: 1, pointer: 3, - stride: None, + stride: 4, row_major: false, ), Emit(( - start: 3, - end: 4, - )), - Emit(( - start: 6, - end: 7, + start: 7, + end: 8, )), Store( - pointer: 7, - value: 6, + pointer: 8, + value: 7, ), + Emit(( + start: 10, + end: 11, + )), CooperativeLoadStore( store: true, target: 1, - pointer: 9, - stride: None, + pointer: 10, + stride: 11, row_major: false, ), - Emit(( - start: 9, - end: 10, - )), Return( value: None, ), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 7582580360..7f8fc73568 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -123,52 +123,54 @@ base: 2, index: 4, ), + Literal(U32(8)), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 4, - b: 5, + a: 5, + b: 6, c: 1, ), LocalVariable(1), GlobalVariable(2), AccessIndex( - base: 8, + base: 9, index: 0, ), + Literal(U32(8)), ], named_expressions: {}, body: [ + Emit(( + start: 3, + end: 4, + )), CooperativeLoadStore( store: false, target: 1, pointer: 3, - stride: None, + stride: 4, row_major: false, ), Emit(( - start: 3, - end: 4, - )), - Emit(( - start: 6, - end: 7, + start: 7, + end: 8, )), Store( - pointer: 7, - value: 6, + pointer: 8, + value: 
7, ), + Emit(( + start: 10, + end: 11, + )), CooperativeLoadStore( store: true, target: 1, - pointer: 9, - stride: None, + pointer: 10, + stride: 11, row_major: false, ), - Emit(( - start: 9, - end: 10, - )), Return( value: None, ), diff --git a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl index bed4406760..4e17948e6b 100644 --- a/naga/tests/out/msl/wgsl-cooperative-matrix.msl +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -24,8 +24,8 @@ kernel void main_( metal::simdgroup_float8x8 b = {}; metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {}; metal::simdgroup_float8x8 d = {}; - metal::simdgroup_load(c, ext[4]); + metal::simdgroup_load(c, ext[4], 8u); d = NagaCooperativeMultiplyAdd(a, b, c); - metal::simdgroup_store(c, ext[0]); + metal::simdgroup_store(c, ext[0], 8u); return; } diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index a3626919a9..56d9e8c7ae 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -71,18 +71,18 @@ OpMemberDecorate %22 0 Offset 0 %28 = OpAccessChain %27 %21 %9 OpBranch %34 %34 = OpLabel +OpLine %3 9 18 OpLine %3 9 5 %37 = OpAccessChain %35 %28 %36 -%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 +%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8 OpStore %30 %38 -OpLine %3 9 18 OpLine %3 10 13 %39 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 OpLine %3 10 5 OpStore %32 %39 +OpLine %3 11 19 OpLine %3 11 5 %40 = OpAccessChain %35 %28 %9 -OpCooperativeMatrixStoreKHR %40 %30 %11 -OpLine %3 11 19 +OpCooperativeMatrixStoreKHR %40 %30 %11 %8 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl new file mode 100644 index 0000000000..2b249bb4d5 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl @@ -0,0 +1,13 @@ +var a: coop_mat8x8; 
+var b: coop_mat8x8; +@group(0) @binding(0) +var ext: array; + +@compute @workgroup_size(8, 8, 1) +fn main() { + var c: coop_mat8x8 = coop_mat8x8(); + var d: coop_mat8x8; + +coopLoad((&c), (&ext[4]), 8u) d = coopMultiplyAdd((&a), (&b), (&c)); +coopStore((&c), (&ext[0]), 8u) return; +} From 87360f29d597f4f81f957f75651ec3247e0c209b Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 25 Sep 2025 23:20:20 -0700 Subject: [PATCH 10/12] coop: rewire IR using native variables load/store --- CHANGELOG.md | 2 +- naga/src/back/mod.rs | 10 +++ naga/src/back/msl/mod.rs | 4 +- naga/src/back/msl/writer.rs | 39 +++++----- naga/src/back/spv/block.rs | 17 +++- naga/src/back/spv/writer.rs | 6 +- naga/src/back/wgsl/writer.rs | 11 ++- naga/src/front/wgsl/lower/mod.rs | 34 ++++---- naga/src/proc/type_methods.rs | 11 --- naga/src/proc/typifier.rs | 18 ++--- naga/src/valid/analyzer.rs | 29 ++++--- naga/src/valid/expression.rs | 40 +++------- naga/src/valid/function.rs | 2 +- naga/tests/in/wgsl/cooperative-matrix.toml | 3 + naga/tests/in/wgsl/cooperative-matrix.wgsl | 3 +- .../analysis/wgsl-cooperative-matrix.info.ron | 78 ------------------- .../ir/wgsl-cooperative-matrix.compact.ron | 72 +++++++++++++---- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 72 +++++++++++++---- .../tests/out/msl/wgsl-cooperative-matrix.msl | 21 +++-- .../out/spv/wgsl-cooperative-matrix.spvasm | 35 ++++++--- .../out/wgsl/wgsl-cooperative-matrix.wgsl | 13 +++- 21 files changed, 282 insertions(+), 238 deletions(-) delete mode 100644 naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron diff --git a/CHANGELOG.md b/CHANGELOG.md index 69759b8e2e..13e9ef04e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -166,7 +166,7 @@ By @cwfitzgerald in [#8162](https://github.com/gfx-rs/wgpu/pull/8162). - Added support for external textures based on WebGPU's [`GPUExternalTexture`](https://www.w3.org/TR/webgpu/#gpuexternaltexture). 
These allow shaders to transparently operate on potentially multiplanar source texture data in either RGB or YCbCr formats via WGSL's `texture_external` type. This is gated behind the `Features::EXTERNAL_TEXTURE` feature, which is currently only supported on DX12. By @jamienicol in [#4386](https://github.com/gfx-rs/wgpu/issues/4386). -- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251). +- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V, METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251). ### Changes diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 0fe8e9274f..092f9d1cd1 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -311,6 +311,16 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { } } +impl crate::TypeInner { + /// Returns true if a variable of this type is a handle. + pub const fn is_handle(&self) -> bool { + match *self { + Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true, + _ => false, + } + } +} + impl crate::Statement { /// Returns true if the statement directly terminates the current block.
/// diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 64b1280a1b..dfeba8f896 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -228,8 +228,10 @@ pub enum Error { UnsupportedArrayOf(String), #[error("array of type '{0:?}' is not supported")] UnsupportedArrayOfType(Handle), - #[error("ray tracing is not supported prior to MSL 2.3")] + #[error("ray tracing is not supported prior to MSL 2.4")] UnsupportedRayTracing, + #[error("cooperative matrix is not supported prior to MSL 2.3")] + UnsupportedCooperativeMatrix, #[error("overrides should not be present at this stage")] Override, #[error("bitcasting to {0:?} is not supported")] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 9e520c4c13..e095b729bf 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -236,6 +236,7 @@ impl Display for TypeContext<'_> { rows, scalar, } => put_numeric_type(out, scalar, &[rows, columns]), + // Requires Metal-2.3 crate::TypeInner::CooperativeMatrix { columns, rows, @@ -244,8 +245,7 @@ impl Display for TypeContext<'_> { } => { write!( out, - "{}::simdgroup_{}{}x{}", - NAMESPACE, + "{NAMESPACE}::simdgroup_{}{}x{}", scalar.to_msl_name(), columns as u32, rows as u32, @@ -485,6 +485,7 @@ enum WrappedFunction { class: crate::ImageClass, }, CooperativeMultiplyAdd { + space: crate::AddressSpace, columns: crate::CooperativeSize, rows: crate::CooperativeSize, intermediate: crate::CooperativeSize, @@ -2842,6 +2843,9 @@ impl Writer { write!(self.out, "}}")?; } crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + if context.lang_version < (2, 3) { + return Err(Error::UnsupportedCooperativeMatrix); + } write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?; self.put_expression(a, context, true)?; write!(self.out, ", ")?; @@ -4239,10 +4243,14 @@ impl Writer { row_major, } => { let op_str = if store { "store" } else { "load" }; - write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?; + 
write!(self.out, "{level}simdgroup_{op_str}(")?; self.put_expression(target, &context.expression, true)?; - write!(self.out, ", ")?; - self.put_expression(pointer, &context.expression, true)?; + write!(self.out, ", &")?; + self.put_access_chain( + pointer, + context.expression.policies.index, + &context.expression, + )?; write!(self.out, ", ")?; self.put_expression(stride, &context.expression, true)?; if row_major { @@ -6312,6 +6320,7 @@ template &mut self, module: &crate::Module, func_ctx: &back::FunctionCtx, + space: crate::AddressSpace, a: Handle, b: Handle, ) -> BackendResult { @@ -6329,6 +6338,7 @@ template _ => unreachable!(), }; let wrapped = WrappedFunction::CooperativeMultiplyAdd { + space, columns: b_c, rows: a_r, intermediate: a_c, @@ -6337,15 +6347,11 @@ template if !self.wrapped_functions.insert(wrapped) { return Ok(()); } - let scalar_name = match scalar.width { - 2 => "half", - 4 => "float", - 8 => "double", - _ => unreachable!(), - }; + let space_name = space.to_msl_name().unwrap_or_default(); + let scalar_name = scalar.to_msl_name(); writeln!( self.out, - "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{", + "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{", b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32, )?; let l1 = back::Level(1); @@ -6354,10 +6360,7 @@ template "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;", b_c as u32, a_r as u32 )?; - writeln!( - self.out, - "{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);" - )?; + writeln!(self.out, 
"{l1}simdgroup_multiply_accumulate(d,a,b,c);")?; writeln!(self.out, "{l1}return d;")?; writeln!(self.out, "}}")?; writeln!(self.out)?; @@ -6439,7 +6442,8 @@ template self.write_wrapped_image_query(module, func_ctx, image, query)?; } crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => { - self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?; + let space = crate::AddressSpace::Private; + self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?; } _ => {} } @@ -6632,7 +6636,6 @@ template names: &self.names, handle, usage: fun_info[handle], - reference: true, }; let separator = diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 0920a86f39..c13604605b 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3726,9 +3726,20 @@ impl BlockContext<'_> { layout_id, self.cached[stride], )); - block - .body - .push(Instruction::store(self.cached[target], id, None)); + match self.write_access_chain( + target, + &mut block, + AccessTypeAdjustment::None, + )? { + ExpressionPointer::Ready { + pointer_id: target_id, + } => { + block.body.push(Instruction::store(target_id, id, None)); + } + ExpressionPointer::Conditional { .. } => { + unimplemented!() + } + }; } } } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index ad08818639..88c3a1629c 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -971,14 +971,13 @@ impl Writer { } } - // Handle globals are pre-emitted and should be loaded automatically. - // - // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. match ir_module.types[var.ty].inner { + // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. crate::TypeInner::BindingArray { .. } => { gv.access_id = gv.var_id; } _ => { + // Handle globals are pre-emitted and should be loaded automatically. 
if var.space == crate::AddressSpace::Handle { let var_type_id = self.get_handle_type_id(var.ty); let id = self.id_gen.next(); @@ -1064,6 +1063,7 @@ impl Writer { } }), ); + context .function .variables diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 95fc3af6c0..0e5aef2f2d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -993,13 +993,13 @@ impl Writer { } => { let op_str = if store { "Store" } else { "Load" }; let suffix = if row_major { "T" } else { "" }; - write!(self.out, "coop{op_str}{suffix}(")?; + write!(self.out, "{level}coop{op_str}{suffix}(")?; self.write_expr(module, target, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, pointer, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, stride, func_ctx)?; - write!(self.out, ")")? + writeln!(self.out, ");")? } } @@ -1118,6 +1118,13 @@ impl Writer { // If the plain form of the expression is not what we need, emit the // operator necessary to correct that. 
let plain = self.plain_form_indirection(expr, module, func_ctx); + log::trace!( + "expression {:?}={:?} is {:?}, expected {:?}", + expr, + func_ctx.expressions[expr], + plain, + requested, + ); match (requested, plain) { (Indirection::Ordinary, Indirection::Reference) => { write!(self.out, "(&")?; diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 7b2c6866cf..eefa1eb90b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -524,6 +524,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { span: Span, ) -> Result<'source, Handle> { let mut eval = self.as_const_evaluator(); + log::debug!("appending {expr:?}"); eval.try_eval_and_append(expr, span) .map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span))) } @@ -846,6 +847,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle { self.as_global().ensure_type_exists(None, inner) } + + fn _get_runtime_expression(&self, expr: Handle) -> &ir::Expression { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr], + ExpressionContextType::Constant(_) | ExpressionContextType::Override => { + unreachable!() + } + } + } } struct ArgumentContext<'ctx, 'source> { @@ -955,6 +965,13 @@ impl Typed { Self::Plain(expr) => Typed::Plain(f(expr)?), }) } + + fn ref_or(self, error: E) -> core::result::Result { + match self { + Self::Reference(v) => Ok(v), + Self::Plain(_) => Err(error), + } + } } /// A single vector component or swizzle. 
@@ -1677,12 +1694,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .as_expression(block, &mut emitter) .interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?; block.extend(emitter.finish(&ctx.function.expressions)); - let typed = if ctx.module.types[ty].inner.is_handle() { - Typed::Plain(handle) - } else { - Typed::Reference(handle) - }; - ctx.local_table.insert(v.handle, Declared::Runtime(typed)); + ctx.local_table + .insert(v.handle, Declared::Runtime(Typed::Reference(handle))); match initializer { Some(initializer) => ir::Statement::Store { @@ -1977,12 +1990,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value_span = ctx.ast_expressions.get_span(value); let target = self .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?; - let target_handle = match target { - Typed::Reference(handle) => handle, - Typed::Plain(_) => { - return Err(Box::new(Error::BadIncrDecrReferenceType(value_span))) - } - }; + let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?; let mut ectx = ctx.as_expression(block, &mut emitter); let scalar = match *resolve_inner!(ectx, target_handle) { @@ -2139,10 +2147,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { LoweredGlobalDecl::Var(handle) => { let expr = ir::Expression::GlobalVariable(handle); let v = &ctx.module.global_variables[handle]; - let force_value = ctx.module.types[v.ty].inner.is_handle(); match v.space { ir::AddressSpace::Handle => Typed::Plain(expr), - _ if force_value => Typed::Plain(expr), _ => Typed::Reference(expr), } } diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index 136ea29218..fe3eb4b626 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -191,17 +191,6 @@ impl crate::TypeInner { } } - /// Returns true if a variable of this type is a handle. - pub const fn is_handle(&self) -> bool { - match *self { - Self::Image { .. } - | Self::Sampler { .. } - | Self::AccelerationStructure { .. 
} - | Self::CooperativeMatrix { .. } => true, - _ => false, - } - } - /// Attempt to calculate the size of this type. Returns `None` if the size /// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`]. pub fn try_size(&self, gctx: super::GlobalCtx) -> Option { diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 8e323d7724..49ce5de0b6 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -454,8 +454,7 @@ impl<'a> ResolveContext<'a> { } crate::Expression::GlobalVariable(h) => { let var = &self.global_vars[h]; - let ty = &types[var.ty].inner; - if var.space == crate::AddressSpace::Handle || ty.is_handle() { + if var.space == crate::AddressSpace::Handle { TypeResolution::Handle(var.ty) } else { TypeResolution::Value(Ti::Pointer { @@ -466,15 +465,10 @@ impl<'a> ResolveContext<'a> { } crate::Expression::LocalVariable(h) => { let var = &self.local_vars[h]; - let ty = &types[var.ty].inner; - if ty.is_handle() { - TypeResolution::Handle(var.ty) - } else { - TypeResolution::Value(Ti::Pointer { - base: var.ty, - space: crate::AddressSpace::Function, - }) - } + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: crate::AddressSpace::Function, + }) } crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) { Ti::Pointer { base, space: _ } => { @@ -493,7 +487,7 @@ impl<'a> ResolveContext<'a> { None => Ti::Scalar(scalar), }), ref other => { - log::error!("Pointer type {other:?}"); + log::error!("Pointer {pointer:?} type {other:?}"); return Err(ResolveError::InvalidPointer(pointer)); } }, diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index fdd98d64f2..fb08f5688e 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1157,21 +1157,28 @@ impl FunctionInfo { FunctionUniformity::new() } S::CooperativeLoadStore { - store: _, + store, target, pointer, stride, row_major: _, - } => FunctionUniformity { - result: Uniformity { - non_uniform_result: self - .add_ref(target) - 
.or(self.add_ref(pointer)) - .or(self.add_ref(stride)), - requirements: UniformityRequirements::COOP_OPS, - }, - exit: ExitFlags::empty(), - }, + } => { + let access = if store { + GlobalUse::WRITE + } else { + GlobalUse::READ + }; + FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .add_ref(target) + .or(self.add_ref_impl(pointer, access)) + .or(self.add_ref(stride)), + requirements: UniformityRequirements::COOP_OPS, + }, + exit: ExitFlags::empty(), + } + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 466ca26b60..ad52c210b4 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1256,34 +1256,18 @@ impl super::Validator { }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, E::CooperativeMultiplyAdd { a, b, c } => { - match resolver[a] { - Ti::CooperativeMatrix { - role: crate::CooperativeRole::A, - .. - } => {} - ref other => { - log::error!("A operand type: {other:?}"); - return Err(ExpressionError::InvalidCooperativeOperand(a)); - } - } - match resolver[b] { - Ti::CooperativeMatrix { - role: crate::CooperativeRole::B, - .. - } => {} - ref other => { - log::error!("B operand type: {other:?}"); - return Err(ExpressionError::InvalidCooperativeOperand(b)); - } - } - match resolver[c] { - Ti::CooperativeMatrix { - role: crate::CooperativeRole::C, - .. - } => {} - ref other => { - log::error!("C operand type: {other:?}"); - return Err(ExpressionError::InvalidCooperativeOperand(c)); + let roles = [ + crate::CooperativeRole::A, + crate::CooperativeRole::B, + crate::CooperativeRole::C, + ]; + for (operand, expected_role) in [a, b, c].into_iter().zip(roles) { + match resolver[operand] { + Ti::CooperativeMatrix { role, .. 
} if role == expected_role => {} + ref other => { + log::error!("{expected_role:?} operand type: {other:?}"); + return Err(ExpressionError::InvalidCooperativeOperand(operand)); + } } } ShaderStages::COMPUTE diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index e1263e88fb..615c524eca 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1078,7 +1078,7 @@ impl super::Validator { } else if let Some(tr) = pointer_base_tr { context.compare_types(value_tr, &tr) } else { - value_ty.is_handle() + false }; if !good { diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml index a95da7bf80..7d20269efc 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.toml +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -4,3 +4,6 @@ god_mode = true [spv] debug = true version = [1, 4] + +[msl] +lang_version = [2, 3] diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index e65fe0d589..5d7b1bfb26 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -8,5 +8,6 @@ fn main() { var c = coop_mat8x8(); coopLoad(c, &ext[4]); var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext[0]); + coopStore(d, &ext[0]); + c = d; } diff --git a/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron b/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron deleted file mode 100644 index f806c3f3dd..0000000000 --- a/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron +++ /dev/null @@ -1,78 +0,0 @@ -( - type_flags: [ - ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), - ], - functions: [], - entry_points: [ - ( - flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), - uniformity: ( - non_uniform_result: None, - requirements: (""), - ), - may_kill: false, - sampling_set: [], - global_uses:
[ - ("READ"), - ], - expressions: [ - ( - uniformity: ( - non_uniform_result: Some(0), - requirements: (""), - ), - ref_count: 1, - assignable_global: Some(0), - ty: Value(Pointer( - base: 0, - space: Private, - )), - ), - ( - uniformity: ( - non_uniform_result: Some(0), - requirements: (""), - ), - ref_count: 1, - assignable_global: None, - ty: Handle(0), - ), - ( - uniformity: ( - non_uniform_result: Some(2), - requirements: (""), - ), - ref_count: 1, - assignable_global: Some(0), - ty: Value(Pointer( - base: 0, - space: Private, - )), - ), - ( - uniformity: ( - non_uniform_result: Some(2), - requirements: (""), - ), - ref_count: 1, - assignable_global: None, - ty: Handle(0), - ), - ( - uniformity: ( - non_uniform_result: Some(0), - requirements: (""), - ), - ref_count: 0, - assignable_global: None, - ty: Handle(0), - ), - ], - sampling: [], - dual_source_blending: false, - diagnostic_filter_leaf: None, - ), - ], - const_expression_types: [], -) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 7f8fc73568..051e07fd1e 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -118,59 +118,97 @@ expressions: [ ZeroValue(4), LocalVariable(0), + Load( + pointer: 1, + ), GlobalVariable(2), AccessIndex( - base: 2, + base: 3, index: 4, ), Literal(U32(8)), GlobalVariable(0), + Load( + pointer: 6, + ), GlobalVariable(1), + Load( + pointer: 8, + ), + Load( + pointer: 1, + ), CooperativeMultiplyAdd( - a: 5, - b: 6, - c: 1, + a: 7, + b: 9, + c: 10, ), LocalVariable(1), + Load( + pointer: 12, + ), GlobalVariable(2), AccessIndex( - base: 9, + base: 14, index: 0, ), Literal(U32(8)), + Load( + pointer: 12, + ), ], named_expressions: {}, body: [ Emit(( - start: 3, - end: 4, + start: 2, + end: 3, + )), + Emit(( + start: 4, + end: 5, )), CooperativeLoadStore( store: false, - target: 1, - pointer: 3, - 
stride: 4, + target: 2, + pointer: 4, + stride: 5, row_major: false, ), Emit(( start: 7, end: 8, )), + Emit(( + start: 9, + end: 12, + )), Store( - pointer: 8, - value: 7, + pointer: 12, + value: 11, ), Emit(( - start: 10, - end: 11, + start: 13, + end: 14, + )), + Emit(( + start: 15, + end: 16, )), CooperativeLoadStore( store: true, - target: 1, - pointer: 10, - stride: 11, + target: 13, + pointer: 15, + stride: 16, row_major: false, ), + Emit(( + start: 17, + end: 18, + )), + Store( + pointer: 1, + value: 17, + ), Return( value: None, ), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 7f8fc73568..051e07fd1e 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -118,59 +118,97 @@ expressions: [ ZeroValue(4), LocalVariable(0), + Load( + pointer: 1, + ), GlobalVariable(2), AccessIndex( - base: 2, + base: 3, index: 4, ), Literal(U32(8)), GlobalVariable(0), + Load( + pointer: 6, + ), GlobalVariable(1), + Load( + pointer: 8, + ), + Load( + pointer: 1, + ), CooperativeMultiplyAdd( - a: 5, - b: 6, - c: 1, + a: 7, + b: 9, + c: 10, ), LocalVariable(1), + Load( + pointer: 12, + ), GlobalVariable(2), AccessIndex( - base: 9, + base: 14, index: 0, ), Literal(U32(8)), + Load( + pointer: 12, + ), ], named_expressions: {}, body: [ Emit(( - start: 3, - end: 4, + start: 2, + end: 3, + )), + Emit(( + start: 4, + end: 5, )), CooperativeLoadStore( store: false, - target: 1, - pointer: 3, - stride: 4, + target: 2, + pointer: 4, + stride: 5, row_major: false, ), Emit(( start: 7, end: 8, )), + Emit(( + start: 9, + end: 12, + )), Store( - pointer: 8, - value: 7, + pointer: 12, + value: 11, ), Emit(( - start: 10, - end: 11, + start: 13, + end: 14, + )), + Emit(( + start: 15, + end: 16, )), CooperativeLoadStore( store: true, - target: 1, - pointer: 10, - stride: 11, + target: 13, + pointer: 15, + stride: 16, row_major: false, ), + Emit(( + start: 17, + end: 18, + )), 
+ Store( + pointer: 1, + value: 17, + ), Return( value: None, ), diff --git a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl index 4e17948e6b..7f7562ce9f 100644 --- a/naga/tests/out/msl/wgsl-cooperative-matrix.msl +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -1,4 +1,4 @@ -// language: metal1.0 +// language: metal2.3 #include #include @@ -9,23 +9,30 @@ struct _mslBufferSizes { }; typedef float type_3[1]; -metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const metal::simdgroup_float8x8& a, const metal::simdgroup_float8x8& b, const metal::simdgroup_float8x8& c) { +metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const thread metal::simdgroup_float8x8& a, const thread metal::simdgroup_float8x8& b, const thread metal::simdgroup_float8x8& c) { metal::simdgroup_float8x8 d; - metal::simdgroup_multiply_accumulate(d,a,b,c); + simdgroup_multiply_accumulate(d,a,b,c); return d; } kernel void main_( - device type_3 const& ext [[user(fake0)]] + device type_3& ext [[user(fake0)]] , constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { metal::simdgroup_float8x8 a = {}; metal::simdgroup_float8x8 b = {}; metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {}; metal::simdgroup_float8x8 d = {}; - metal::simdgroup_load(c, ext[4], 8u); - d = NagaCooperativeMultiplyAdd(a, b, c); - metal::simdgroup_store(c, ext[0], 8u); + metal::simdgroup_float8x8 _e2 = c; + simdgroup_load(_e2, &ext[4], 8u); + metal::simdgroup_float8x8 _e7 = a; + metal::simdgroup_float8x8 _e9 = b; + metal::simdgroup_float8x8 _e10 = c; + d = NagaCooperativeMultiplyAdd(_e7, _e9, _e10); + metal::simdgroup_float8x8 _e13 = d; + simdgroup_store(_e13, &ext[0], 8u); + metal::simdgroup_float8x8 _e17 = d; + c = _e17; return; } diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index 56d9e8c7ae..d02c1a3cb5 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ 
b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 41 +; Bound: 47 OpCapability Shader OpCapability CooperativeMatrixKHR OpCapability VulkanMemoryModel @@ -22,7 +22,8 @@ fn main() { var c = coop_mat8x8(); coopLoad(c, &ext[4]); var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext[0]); + coopStore(d, &ext[0]); + c = d; } " OpName %15 "a" @@ -62,8 +63,8 @@ OpMemberDecorate %22 0 Offset 0 %29 = OpConstantNull %13 %31 = OpTypePointer Function %13 %33 = OpConstantNull %13 -%35 = OpTypePointer StorageBuffer %4 -%36 = OpConstant %7 4 +%36 = OpTypePointer StorageBuffer %4 +%37 = OpConstant %7 4 %25 = OpFunction %2 None %26 %24 = OpLabel %30 = OpVariable %31 Function %29 @@ -71,18 +72,30 @@ OpMemberDecorate %22 0 Offset 0 %28 = OpAccessChain %27 %21 %9 OpBranch %34 %34 = OpLabel +OpLine %3 1 1 +%35 = OpLoad %13 %30 OpLine %3 9 18 OpLine %3 9 5 -%37 = OpAccessChain %35 %28 %36 -%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8 -OpStore %30 %38 +%38 = OpAccessChain %36 %28 %37 +%39 = OpCooperativeMatrixLoadKHR %13 %38 %11 %8 +OpStore %35 %39 +OpLine %3 10 29 +%40 = OpLoad %5 %15 OpLine %3 10 13 -%39 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 +%41 = OpLoad %10 %18 +%42 = OpLoad %13 %30 +%43 = OpCooperativeMatrixMulAddKHR %13 %40 %41 %42 OpLine %3 10 5 -OpStore %32 %39 +OpStore %32 %43 +OpLine %3 1 1 +%44 = OpLoad %13 %32 OpLine %3 11 19 OpLine %3 11 5 -%40 = OpAccessChain %35 %28 %9 -OpCooperativeMatrixStoreKHR %40 %30 %11 %8 +%45 = OpAccessChain %36 %28 %9 +OpCooperativeMatrixStoreKHR %45 %44 %11 %8 +OpLine %3 1 1 +%46 = OpLoad %13 %32 +OpLine %3 12 5 +OpStore %30 %46 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl index 2b249bb4d5..af7a6195b6 100644 --- a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl @@ -8,6 +8,15 @@ fn main() { var 
c: coop_mat8x8 = coop_mat8x8(); var d: coop_mat8x8; -coopLoad((&c), (&ext[4]), 8u) d = coopMultiplyAdd((&a), (&b), (&c)); -coopStore((&c), (&ext[0]), 8u) return; + let _e2 = c; + coopLoad(_e2, (&ext[4]), 8u); + let _e7 = a; + let _e9 = b; + let _e10 = c; + d = coopMultiplyAdd(_e7, _e9, _e10); + let _e13 = d; + coopStore(_e13, (&ext[0]), 8u); + let _e17 = d; + c = _e17; + return; } From 07be9e9a1a2019226745901e81363c19e9348026 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 28 Sep 2025 00:37:15 -0700 Subject: [PATCH 11/12] coop: make cooperativeLoad to be an expression --- naga/src/back/dot/mod.rs | 24 ++--- naga/src/back/glsl/mod.rs | 3 +- naga/src/back/hlsl/writer.rs | 3 +- naga/src/back/msl/writer.rs | 98 ++++++++++++++++--- naga/src/back/pipeline_constants.rs | 15 +-- naga/src/back/spv/block.rs | 83 ++++++++-------- naga/src/back/wgsl/writer.rs | 25 ++--- naga/src/compact/expressions.rs | 8 ++ naga/src/compact/statements.rs | 23 ++--- naga/src/front/spv/mod.rs | 2 +- naga/src/front/wgsl/lower/mod.rs | 48 +++++++-- naga/src/ir/mod.rs | 26 +++-- naga/src/proc/constant_evaluator.rs | 2 +- naga/src/proc/terminator.rs | 2 +- naga/src/proc/typifier.rs | 18 ++++ naga/src/valid/analyzer.rs | 37 +++---- naga/src/valid/expression.rs | 10 ++ naga/src/valid/function.rs | 45 +++------ naga/src/valid/handles.rs | 15 ++- naga/tests/in/wgsl/cooperative-matrix.wgsl | 3 +- .../ir/wgsl-cooperative-matrix.compact.ron | 94 +++++++++--------- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 94 +++++++++--------- .../tests/out/msl/wgsl-cooperative-matrix.msl | 27 ++--- .../out/spv/wgsl-cooperative-matrix.spvasm | 60 ++++++------ .../out/wgsl/wgsl-cooperative-matrix.wgsl | 21 ++-- 25 files changed, 462 insertions(+), 324 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 358c4c35c4..7be93c9003 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -403,20 +403,14 @@ impl StatementGraph { }, } } - S::CooperativeLoadStore { - 
store, - target, - pointer, - stride, - row_major: _, - } => { + S::CooperativeStore { target, data } => { self.dependencies.push((id, target, "target")); - self.dependencies.push((id, pointer, "pointer")); - self.dependencies.push((id, stride, "stride")); - if store { - "Store" + self.dependencies.push((id, data.pointer, "pointer")); + self.dependencies.push((id, data.stride, "stride")); + if data.row_major { + "CoopStoreT" } else { - "Load" + "CoopStore" } } }; @@ -758,6 +752,12 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("get{ty}HitVertexPositions").into(), 4) } + E::CooperativeLoad { ref data, .. } => { + edges.insert("pointer", data.pointer); + edges.insert("stride", data.stride); + let suffix = if data.row_major { "T" } else { "" }; + (format!("coopLoad{suffix}").into(), 4) + } E::CooperativeMultiplyAdd { a, b, c } => { edges.insert("a", a); edges.insert("b", b); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 8a015fccd9..3c3b70289a 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2805,7 +2805,7 @@ impl<'a, W: Write> Writer<'a, W> { } writeln!(self.out, ");")?; } - Statement::CooperativeLoadStore { .. } => unimplemented!(), + Statement::CooperativeStore { .. } => unimplemented!(), } Ok(()) @@ -4343,6 +4343,7 @@ impl<'a, W: Write> Writer<'a, W> { // not supported yet Expression::RayQueryGetIntersection { .. } | Expression::RayQueryVertexPositions { .. } + | Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => unreachable!(), } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index edf76e2c20..33ad58e17b 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2747,7 +2747,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } writeln!(self.out, ");")?; } - Statement::CooperativeLoadStore { .. } => unimplemented!(), + Statement::CooperativeStore { ..
} => unimplemented!(), } Ok(()) @@ -4277,6 +4277,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } // Not supported yet Expression::RayQueryVertexPositions { .. } + | Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => { unreachable!() } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index e095b729bf..952e52eda6 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp /// allowing them to be conveniently passed to user-defined or wrapper /// functions. The struct is declared in [`Writer::write_type_defs`]. pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper"; +pub(crate) const COOPERATIVE_LOAD_FUNCTION: &str = "NagaCooperativeLoad"; pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd"; /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. @@ -484,6 +485,12 @@ enum WrappedFunction { ImageQuerySize { class: crate::ImageClass, }, + CooperativeLoad { + space: crate::AddressSpace, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + scalar: crate::Scalar, + }, CooperativeMultiplyAdd { space: crate::AddressSpace, columns: crate::CooperativeSize, @@ -2842,6 +2849,17 @@ impl Writer { } write!(self.out, "}}")?; } + crate::Expression::CooperativeLoad { ref data, .. 
} => { + if context.lang_version < (2, 3) { + return Err(Error::UnsupportedCooperativeMatrix); + } + write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?; + write!(self.out, "&")?; + self.put_access_chain(data.pointer, context.policies.index, context)?; + write!(self.out, ", ")?; + self.put_expression(data.stride, context, true)?; + write!(self.out, ", {})", data.row_major)?; + } crate::Expression::CooperativeMultiplyAdd { a, b, c } => { if context.lang_version < (2, 3) { return Err(Error::UnsupportedCooperativeMatrix); @@ -4235,25 +4253,18 @@ impl Writer { } writeln!(self.out, ");")?; } - crate::Statement::CooperativeLoadStore { - store, - target, - pointer, - stride, - row_major, - } => { - let op_str = if store { "store" } else { "load" }; - write!(self.out, "{level}simdgroup_{op_str}(")?; + crate::Statement::CooperativeStore { target, ref data } => { + write!(self.out, "{level}simdgroup_store(")?; self.put_expression(target, &context.expression, true)?; write!(self.out, ", &")?; self.put_access_chain( - pointer, + data.pointer, context.expression.policies.index, &context.expression, )?; write!(self.out, ", ")?; - self.put_expression(stride, &context.expression, true)?; - if row_major { + self.put_expression(data.stride, &context.expression, true)?; + if data.row_major { let matrix_origin = "0"; let transpose = true; write!(self.out, ", {matrix_origin}, {transpose}")?; @@ -6316,6 +6327,55 @@ template Ok(()) } + fn write_wrapped_cooperative_load( + &mut self, + module: &crate::Module, + func_ctx: &back::FunctionCtx, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + pointer: Handle, + ) -> BackendResult { + let ptr_ty = func_ctx.resolve_type(pointer, &module.types); + let space = ptr_ty.pointer_space().unwrap(); + let scalar = ptr_ty + .pointer_base_type() + .unwrap() + .inner_with(&module.types) + .scalar() + .unwrap(); + let wrapped = WrappedFunction::CooperativeLoad { + space, + columns, + rows, + scalar, + }; + if 
!self.wrapped_functions.insert(wrapped) { + return Ok(()); + } + let space_name = space.to_msl_name().unwrap_or_default(); + let scalar_name = scalar.to_msl_name(); + writeln!( + self.out, + "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{", + columns as u32, rows as u32, + )?; + let l1 = back::Level(1); + writeln!( + self.out, + "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;", + columns as u32, rows as u32 + )?; + let matrix_origin = "0"; + writeln!( + self.out, + "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);" + )?; + writeln!(self.out, "{l1}return m;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } + fn write_wrapped_cooperative_multiply_add( &mut self, module: &crate::Module, @@ -6441,6 +6501,20 @@ template crate::Expression::ImageQuery { image, query } => { self.write_wrapped_image_query(module, func_ctx, image, query)?; } + crate::Expression::CooperativeLoad { + columns, + rows, + role: _, + ref data, + } => { + self.write_wrapped_cooperative_load( + module, + func_ctx, + columns, + rows, + data.pointer, + )?; + } crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => { let space = crate::AddressSpace::Private; self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?; diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 5a9fd9558d..6a5ce44289 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -633,6 +633,10 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E } => { adjust(query); } + Expression::CooperativeLoad { ref mut data, .. 
} => { + adjust(&mut data.pointer); + adjust(&mut data.stride); + } Expression::CooperativeMultiplyAdd { ref mut a, ref mut b, @@ -844,16 +848,13 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } - Statement::CooperativeLoadStore { - store: _, + Statement::CooperativeStore { ref mut target, - ref mut pointer, - ref mut stride, - row_major: _, + ref mut data, } => { adjust(target); - adjust(pointer); - adjust(stride); + adjust(&mut data.pointer); + adjust(&mut data.stride); } Statement::Break | Statement::Continue diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index c13604605b..14496957c1 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1805,6 +1805,39 @@ impl BlockContext<'_> { )?; self.write_ray_query_return_vertex_position(query, block, committed) } + crate::Expression::CooperativeLoad { ref data, .. } => { + self.writer.require_any( + "CooperativeMatrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + let pointer_id = match self.write_access_chain( + data.pointer, + block, + AccessTypeAdjustment::None, + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. 
} => { + return Err(Error::FeatureNotImplemented( + "Copperative load/store out-of-bounds handling", + )); + } + }; + let layout = if data.row_major { + spirv::CooperativeMatrixLayout::RowMajorKHR + } else { + spirv::CooperativeMatrixLayout::ColumnMajorKHR + }; + let layout_id = self.get_index_constant(layout as u32); + let id = self.gen_id(); + block.body.push(Instruction::coop_load( + result_type_id, + id, + pointer_id, + layout_id, + self.cached[data.stride], + )); + id + } crate::Expression::CooperativeMultiplyAdd { a, b, c } => { self.writer.require_any( "CooperativeMatrix", @@ -3684,15 +3717,9 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } - Statement::CooperativeLoadStore { - store, - target, - pointer, - stride, - row_major, - } => { + Statement::CooperativeStore { target, ref data } => { let pointer_id = match self.write_access_chain( - pointer, + data.pointer, &mut block, AccessTypeAdjustment::None, )? { @@ -3703,44 +3730,18 @@ impl BlockContext<'_> { )); } }; - let layout = if row_major { + let layout = if data.row_major { spirv::CooperativeMatrixLayout::RowMajorKHR } else { spirv::CooperativeMatrixLayout::ColumnMajorKHR }; let layout_id = self.get_index_constant(layout as u32); - if store { - block.body.push(Instruction::coop_store( - self.cached[target], - pointer_id, - layout_id, - self.cached[stride], - )); - } else { - let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty); - let id = self.gen_id(); - block.body.push(Instruction::coop_load( - result_type_id, - id, - pointer_id, - layout_id, - self.cached[stride], - )); - match self.write_access_chain( - target, - &mut block, - AccessTypeAdjustment::None, - )? { - ExpressionPointer::Ready { - pointer_id: target_id, - } => { - block.body.push(Instruction::store(target_id, id, None)); - } - ExpressionPointer::Conditional { .. 
} => { - unimplemented!() - } - }; - } + block.body.push(Instruction::coop_store( + self.cached[target], + pointer_id, + layout_id, + self.cached[data.stride], + )); } } } diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 0e5aef2f2d..9a116d7464 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -984,21 +984,14 @@ impl Writer { } writeln!(self.out, ");")?; } - Statement::CooperativeLoadStore { - store, - target, - pointer, - stride, - row_major, - } => { - let op_str = if store { "Store" } else { "Load" }; - let suffix = if row_major { "T" } else { "" }; - write!(self.out, "{level}coop{op_str}{suffix}(")?; + Statement::CooperativeStore { target, ref data } => { + let suffix = if data.row_major { "T" } else { "" }; + write!(self.out, "{level}coopStore{suffix}(")?; self.write_expr(module, target, func_ctx)?; write!(self.out, ", ")?; - self.write_expr(module, pointer, func_ctx)?; + self.write_expr(module, data.pointer, func_ctx)?; write!(self.out, ", ")?; - self.write_expr(module, stride, func_ctx)?; + self.write_expr(module, data.stride, func_ctx)?; writeln!(self.out, ");")? } } @@ -1719,6 +1712,14 @@ impl Writer { | Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} + Expression::CooperativeLoad { ref data, .. 
} => { + let suffix = if data.row_major { "T" } else { "" }; + write!(self.out, "coopLoad{suffix}(")?; + self.write_expr(module, data.pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, data.stride, func_ctx)?; + write!(self.out, ")")?; + } Expression::CooperativeMultiplyAdd { a, b, c } => { write!(self.out, "coopMultiplyAdd(")?; self.write_expr(module, a, func_ctx)?; diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 2b2117cc16..021401d00e 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -253,6 +253,10 @@ impl ExpressionTracer<'_> { } => { self.expressions_used.insert(query); } + Ex::CooperativeLoad { ref data, .. } => { + self.expressions_used.insert(data.pointer); + self.expressions_used.insert(data.stride); + } Ex::CooperativeMultiplyAdd { a, b, c } => { self.expressions_used.insert(a); self.expressions_used.insert(b); @@ -424,6 +428,10 @@ impl ModuleMap { ref mut query, committed: _, } => adjust(query), + Ex::CooperativeLoad { ref mut data, .. } => { + adjust(&mut data.pointer); + adjust(&mut data.stride); + } Ex::CooperativeMultiplyAdd { ref mut a, ref mut b, diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 5c36a40274..af72cb872a 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -152,16 +152,10 @@ impl FunctionTracer<'_> { self.expressions_used.insert(argument); self.expressions_used.insert(result); } - St::CooperativeLoadStore { - store: _, - target, - pointer, - stride, - row_major: _, - } => { + St::CooperativeStore { target, ref data } => { self.expressions_used.insert(target); - self.expressions_used.insert(pointer); - self.expressions_used.insert(stride); + self.expressions_used.insert(data.pointer); + self.expressions_used.insert(data.stride); } // Trivial statements. 
@@ -382,16 +376,13 @@ impl FunctionMap { adjust(argument); adjust(result); } - St::CooperativeLoadStore { - store: _, + St::CooperativeStore { ref mut target, - ref mut pointer, - ref mut stride, - row_major: _, + ref mut data, } => { adjust(target); - adjust(pointer); - adjust(stride); + adjust(&mut data.pointer); + adjust(&mut data.stride); } // Trivial statements. diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index e5ae2ec2e8..aa29bbffae 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4654,7 +4654,7 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), - S::CooperativeLoadStore { .. } => unreachable!(), + S::CooperativeStore { .. } => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index eefa1eb90b..fd005a3264 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3141,8 +3141,41 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "coopLoad" | "coopLoadT" | "coopStore" | "coopStoreT" => { - let store = function.name.contains("Store"); + "coopLoad" | "coopLoadT" => { + let row_major = function.name.ends_with("T"); + let mut args = ctx.prepare_args(arguments, 1, span); + let pointer = self.expression(args.next()?, ctx)?; + //TODO: read from generic argument + let columns = crate::CooperativeSize::Eight; + let rows = crate::CooperativeSize::Eight; + let stride = if args.total_args > 1 { + self.expression(args.next()?, ctx)? + } else { + // Infer the stride from the matrix type + let stride = if row_major { + columns as u32 + } else { + rows as u32 + }; + ctx.append_expression( + ir::Expression::Literal(ir::Literal::U32(stride)), + Span::UNDEFINED, + )? 
+ }; + args.finish()?; + + crate::Expression::CooperativeLoad { + columns, + rows, + role: crate::CooperativeRole::C, //TODO + data: crate::CooperativeData { + pointer, + stride, + row_major, + }, + } + } + "coopStore" | "coopStoreT" => { let row_major = function.name.ends_with("T"); let mut args = ctx.prepare_args(arguments, 2, span); @@ -3171,12 +3204,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::CooperativeLoadStore { - store, + crate::Statement::CooperativeStore { target, - pointer, - stride, - row_major, + data: crate::CooperativeData { + pointer, + stride, + row_major, + }, }, span, ); diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 854c7d5719..e1a90619b7 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -1421,6 +1421,16 @@ bitflags::bitflags! { } } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct CooperativeData { + pub pointer: Handle, + pub stride: Handle, + pub row_major: bool, +} + /// An expression that can be evaluated to obtain a value. /// /// This is a Single Static Assignment (SSA) scheme similar to SPIR-V. @@ -1764,6 +1774,13 @@ pub enum Expression { /// [`SubgroupGather`]: Statement::SubgroupGather SubgroupOperationResult { ty: Handle }, + /// Load a cooperative primitive from memory. + CooperativeLoad { + columns: CooperativeSize, + rows: CooperativeSize, + role: CooperativeRole, + data: CooperativeData, + }, /// Compute `a * b + c` CooperativeMultiplyAdd { a: Handle, @@ -2211,13 +2228,10 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, - /// Load from or store into a cooperative primitive. - CooperativeLoadStore { - store: bool, + /// Store a cooperative primitive into memory. 
+ CooperativeStore { target: Handle, - pointer: Handle, - stride: Handle, - row_major: bool, + data: CooperativeData, }, } diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index b0508193bf..c80789b2df 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -973,7 +973,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::SubgroupOperationResult { .. } => { Err(ConstantEvaluatorError::SubgroupExpression) } - Expression::CooperativeMultiplyAdd { .. } => { + Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => { Err(ConstantEvaluatorError::CooperativeOperation) } } diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index a670694f23..6ffd815930 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -44,7 +44,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::SubgroupGather { .. } | S::ControlBarrier(_) | S::MemoryBarrier(_) - | S::CooperativeLoadStore { .. }), + | S::CooperativeStore { .. }), ) | None => block.push(S::Return { value: None }, Default::default()), } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 49ce5de0b6..1d46980774 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -801,6 +801,24 @@ impl<'a> ResolveContext<'a> { scalar: crate::Scalar::U32, size: crate::VectorSize::Quad, }), + crate::Expression::CooperativeLoad { + columns, + rows, + role, + ref data, + } => { + let scalar = past(data.pointer)? 
+ .inner_with(types) + .pointer_base_type() + .and_then(|tr| tr.inner_with(types).scalar()) + .ok_or(ResolveError::InvalidPointer(data.pointer))?; + TypeResolution::Value(Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }) + } crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(), }) } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index fb08f5688e..47a43731b8 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -823,6 +823,10 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::CooperativeLoad { ref data, .. } => Uniformity { + non_uniform_result: self.add_ref(data.pointer).or(self.add_ref(data.stride)), + requirements: UniformityRequirements::COOP_OPS, + }, E::CooperativeMultiplyAdd { a, b, c } => Uniformity { non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))), requirements: UniformityRequirements::COOP_OPS, @@ -1156,29 +1160,16 @@ impl FunctionInfo { } FunctionUniformity::new() } - S::CooperativeLoadStore { - store, - target, - pointer, - stride, - row_major: _, - } => { - let access = if store { - GlobalUse::WRITE - } else { - GlobalUse::READ - }; - FunctionUniformity { - result: Uniformity { - non_uniform_result: self - .add_ref(target) - .or(self.add_ref_impl(pointer, access)) - .or(self.add_ref(stride)), - requirements: UniformityRequirements::COOP_OPS, - }, - exit: ExitFlags::empty(), - } - } + S::CooperativeStore { target, ref data } => FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .add_ref(target) + .or(self.add_ref_impl(data.pointer, GlobalUse::WRITE)) + .or(self.add_ref(data.stride)), + requirements: UniformityRequirements::COOP_OPS, + }, + exit: ExitFlags::empty(), + }, }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index ad52c210b4..8f39689642 100644 --- 
a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1255,6 +1255,16 @@ impl super::Validator { } }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, + E::CooperativeLoad { ref data, .. } => { + if resolver[data.pointer] + .pointer_base_type() + .and_then(|tr| tr.inner_with(&module.types).scalar()) + .is_none() + { + return Err(ExpressionError::InvalidPointerType(data.pointer)); + } + ShaderStages::COMPUTE + } E::CooperativeMultiplyAdd { a, b, c } => { let roles = [ crate::CooperativeRole::A, diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 615c524eca..5c45aa4407 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -800,6 +800,7 @@ impl super::Validator { | Ex::ArrayLength(_) | Ex::RayQueryGetIntersection { .. } | Ex::RayQueryVertexPositions { .. } + | Ex::CooperativeLoad { .. } | Ex::CooperativeMultiplyAdd { .. } => { self.emit_expression(handle, context)? } @@ -1622,13 +1623,7 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } - S::CooperativeLoadStore { - store, - target, - pointer, - stride: _, - row_major: _, - } => { + S::CooperativeStore { target, ref data } => { stages &= super::ShaderStages::COMPUTE; let target_scalar = @@ -1641,30 +1636,22 @@ impl super::Validator { } }; - let ty_inner = context.resolve_pointer_type(pointer); - //TODO: validate stride - let (pty_scalar, space) = match *ty_inner { - crate::TypeInner::Pointer { base, space } => (base, space), - _ => { - return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)); - } - }; - let space = match context.types[pty_scalar].inner { - crate::TypeInner::Scalar(s) if s == target_scalar => space, - _ => { - return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)); - } - }; + let ptr_ty = context.resolve_pointer_type(data.pointer); + let 
ptr_scalar = ptr_ty + .pointer_base_type() + .and_then(|tr| tr.inner_with(context.types).scalar()); + if ptr_scalar != Some(target_scalar) { + return Err(FunctionError::InvalidCooperativeDataPointer(data.pointer) + .with_span_handle(data.pointer, context.expressions)); + } - if store && !space.access().contains(crate::StorageAccess::STORE) { - return Err( - FunctionError::InvalidStorePointer(pointer).with_span_static( - context.expressions.get_span(pointer), + let ptr_space = ptr_ty.pointer_space().unwrap_or(AddressSpace::Handle); + if !ptr_space.access().contains(crate::StorageAccess::STORE) { + return Err(FunctionError::InvalidStorePointer(data.pointer) + .with_span_static( + context.expressions.get_span(data.pointer), "writing to this location is not permitted", - ), - ); + )); } } } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 1bd33eaf3c..4df550dbc6 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -648,6 +648,9 @@ impl super::Validator { } => { handle.check_dep(query)?; } + crate::Expression::CooperativeLoad { ref data, .. 
} => { + handle.check_dep(data.pointer)?.check_dep(data.stride)?; + } crate::Expression::CooperativeMultiplyAdd { a, b, c } => { handle.check_dep(a)?.check_dep(b)?.check_dep(c)?; } @@ -839,16 +842,10 @@ impl super::Validator { validate_expr(result)?; Ok(()) } - crate::Statement::CooperativeLoadStore { - store: _, - target, - pointer, - stride, - row_major: _, - } => { + crate::Statement::CooperativeStore { target, ref data } => { validate_expr(target)?; - validate_expr(pointer)?; - validate_expr(stride)?; + validate_expr(data.pointer)?; + validate_expr(data.stride)?; Ok(()) } crate::Statement::Break diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 5d7b1bfb26..641af45b7f 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -5,8 +5,7 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { - var c = coop_mat8x8(); - coopLoad(c, &ext[4]); + var c = coopLoad(&ext[4]); var d = coopMultiplyAdd(a, b, c); coopStore(d, &ext[0]); c = d; diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 051e07fd1e..d31e45cd6f 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -107,7 +107,7 @@ ( name: Some("c"), ty: 4, - init: Some(0), + init: None, ), ( name: Some("d"), @@ -116,98 +116,102 @@ ), ], expressions: [ - ZeroValue(4), - LocalVariable(0), - Load( - pointer: 1, - ), GlobalVariable(2), AccessIndex( - base: 3, + base: 0, index: 4, ), Literal(U32(8)), + CooperativeLoad( + columns: Eight, + rows: Eight, + role: C, + data: ( + pointer: 1, + stride: 2, + row_major: false, + ), + ), + LocalVariable(0), GlobalVariable(0), Load( - pointer: 6, + pointer: 5, ), GlobalVariable(1), Load( - pointer: 8, + pointer: 7, ), Load( - pointer: 1, + pointer: 4, ), CooperativeMultiplyAdd( - a: 7, - b: 9, - c: 10, + a: 6, + b: 8, + c: 
9, ), LocalVariable(1), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(2), AccessIndex( - base: 14, + base: 13, index: 0, ), Literal(U32(8)), Load( - pointer: 12, + pointer: 11, ), ], named_expressions: {}, body: [ Emit(( - start: 2, - end: 3, + start: 1, + end: 2, )), Emit(( - start: 4, - end: 5, + start: 3, + end: 4, )), - CooperativeLoadStore( - store: false, - target: 2, + Store( pointer: 4, - stride: 5, - row_major: false, + value: 3, ), Emit(( - start: 7, - end: 8, + start: 6, + end: 7, )), Emit(( - start: 9, - end: 12, + start: 8, + end: 11, )), Store( - pointer: 12, - value: 11, + pointer: 11, + value: 10, ), Emit(( - start: 13, - end: 14, + start: 12, + end: 13, )), Emit(( - start: 15, - end: 16, + start: 14, + end: 15, )), - CooperativeLoadStore( - store: true, - target: 13, - pointer: 15, - stride: 16, - row_major: false, + CooperativeStore( + target: 12, + data: ( + pointer: 14, + stride: 15, + row_major: false, + ), ), Emit(( - start: 17, - end: 18, + start: 16, + end: 17, )), Store( - pointer: 1, - value: 17, + pointer: 4, + value: 16, ), Return( value: None, diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 051e07fd1e..d31e45cd6f 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -107,7 +107,7 @@ ( name: Some("c"), ty: 4, - init: Some(0), + init: None, ), ( name: Some("d"), @@ -116,98 +116,102 @@ ), ], expressions: [ - ZeroValue(4), - LocalVariable(0), - Load( - pointer: 1, - ), GlobalVariable(2), AccessIndex( - base: 3, + base: 0, index: 4, ), Literal(U32(8)), + CooperativeLoad( + columns: Eight, + rows: Eight, + role: C, + data: ( + pointer: 1, + stride: 2, + row_major: false, + ), + ), + LocalVariable(0), GlobalVariable(0), Load( - pointer: 6, + pointer: 5, ), GlobalVariable(1), Load( - pointer: 8, + pointer: 7, ), Load( - pointer: 1, + pointer: 4, ), CooperativeMultiplyAdd( - a: 7, - b: 9, - c: 10, + a: 6, + b: 8, + 
c: 9, ), LocalVariable(1), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(2), AccessIndex( - base: 14, + base: 13, index: 0, ), Literal(U32(8)), Load( - pointer: 12, + pointer: 11, ), ], named_expressions: {}, body: [ Emit(( - start: 2, - end: 3, + start: 1, + end: 2, )), Emit(( - start: 4, - end: 5, + start: 3, + end: 4, )), - CooperativeLoadStore( - store: false, - target: 2, + Store( pointer: 4, - stride: 5, - row_major: false, + value: 3, ), Emit(( - start: 7, - end: 8, + start: 6, + end: 7, )), Emit(( - start: 9, - end: 12, + start: 8, + end: 11, )), Store( - pointer: 12, - value: 11, + pointer: 11, + value: 10, ), Emit(( - start: 13, - end: 14, + start: 12, + end: 13, )), Emit(( - start: 15, - end: 16, + start: 14, + end: 15, )), - CooperativeLoadStore( - store: true, - target: 13, - pointer: 15, - stride: 16, - row_major: false, + CooperativeStore( + target: 12, + data: ( + pointer: 14, + stride: 15, + row_major: false, + ), ), Emit(( - start: 17, - end: 18, + start: 16, + end: 17, )), Store( - pointer: 1, - value: 17, + pointer: 4, + value: 16, ), Return( value: None, diff --git a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl index 7f7562ce9f..604ec4a169 100644 --- a/naga/tests/out/msl/wgsl-cooperative-matrix.msl +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -9,6 +9,12 @@ struct _mslBufferSizes { }; typedef float type_3[1]; +metal::simdgroup_float8x8 NagaCooperativeLoad(const device float* ptr, int stride, bool is_row_major) { + metal::simdgroup_float8x8 m; + simdgroup_load(m, ptr, stride, 0, is_row_major); + return m; +} + metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const thread metal::simdgroup_float8x8& a, const thread metal::simdgroup_float8x8& b, const thread metal::simdgroup_float8x8& c) { metal::simdgroup_float8x8 d; simdgroup_multiply_accumulate(d,a,b,c); @@ -22,17 +28,16 @@ kernel void main_( ) { metal::simdgroup_float8x8 a = {}; metal::simdgroup_float8x8 b = {}; - 
metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {}; + metal::simdgroup_float8x8 c = {}; metal::simdgroup_float8x8 d = {}; - metal::simdgroup_float8x8 _e2 = c; - simdgroup_load(_e2, &ext[4], 8u); - metal::simdgroup_float8x8 _e7 = a; - metal::simdgroup_float8x8 _e9 = b; - metal::simdgroup_float8x8 _e10 = c; - d = NagaCooperativeMultiplyAdd(_e7, _e9, _e10); - metal::simdgroup_float8x8 _e13 = d; - simdgroup_store(_e13, &ext[0], 8u); - metal::simdgroup_float8x8 _e17 = d; - c = _e17; + c = NagaCooperativeLoad(&ext[4], 8u, false); + metal::simdgroup_float8x8 _e6 = a; + metal::simdgroup_float8x8 _e8 = b; + metal::simdgroup_float8x8 _e9 = c; + d = NagaCooperativeMultiplyAdd(_e6, _e8, _e9); + metal::simdgroup_float8x8 _e12 = d; + simdgroup_store(_e12, &ext[0], 8u); + metal::simdgroup_float8x8 _e16 = d; + c = _e16; return; } diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index d02c1a3cb5..4f443448d7 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 47 +; Bound: 46 OpCapability Shader OpCapability CooperativeMatrixKHR OpCapability VulkanMemoryModel @@ -19,8 +19,7 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { - var c = coop_mat8x8(); - coopLoad(c, &ext[4]); + var c = coopLoad(&ext[4]); var d = coopMultiplyAdd(a, b, c); coopStore(d, &ext[0]); c = d; @@ -30,7 +29,7 @@ OpName %15 "a" OpName %18 "b" OpName %21 "ext" OpName %25 "main" -OpName %30 "c" +OpName %29 "c" OpName %32 "d" OpDecorate %12 ArrayStride 4 OpDecorate %21 DescriptorSet 0 @@ -60,42 +59,41 @@ OpMemberDecorate %22 0 Offset 0 %21 = OpVariable %23 StorageBuffer %26 = OpTypeFunction %2 %27 = OpTypePointer StorageBuffer %12 -%29 = OpConstantNull %13 -%31 = OpTypePointer Function %13 +%30 = OpTypePointer Function %13 +%31 = OpConstantNull %13 %33 = OpConstantNull %13 -%36 = OpTypePointer 
StorageBuffer %4 -%37 = OpConstant %7 4 +%35 = OpTypePointer StorageBuffer %4 +%36 = OpConstant %7 4 %25 = OpFunction %2 None %26 %24 = OpLabel -%30 = OpVariable %31 Function %29 -%32 = OpVariable %31 Function %33 +%29 = OpVariable %30 Function %31 +%32 = OpVariable %30 Function %33 %28 = OpAccessChain %27 %21 %9 OpBranch %34 %34 = OpLabel -OpLine %3 1 1 -%35 = OpLoad %13 %30 -OpLine %3 9 18 +OpLine %3 8 23 +OpLine %3 8 13 +%37 = OpAccessChain %35 %28 %36 +%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8 +OpLine %3 8 5 +OpStore %29 %38 +OpLine %3 9 29 +%39 = OpLoad %5 %15 +OpLine %3 9 13 +%40 = OpLoad %10 %18 +%41 = OpLoad %13 %29 +%42 = OpCooperativeMatrixMulAddKHR %13 %39 %40 %41 OpLine %3 9 5 -%38 = OpAccessChain %36 %28 %37 -%39 = OpCooperativeMatrixLoadKHR %13 %38 %11 %8 -OpStore %35 %39 -OpLine %3 10 29 -%40 = OpLoad %5 %15 -OpLine %3 10 13 -%41 = OpLoad %10 %18 -%42 = OpLoad %13 %30 -%43 = OpCooperativeMatrixMulAddKHR %13 %40 %41 %42 +OpStore %32 %42 +OpLine %3 1 1 +%43 = OpLoad %13 %32 +OpLine %3 10 19 OpLine %3 10 5 -OpStore %32 %43 +%44 = OpAccessChain %35 %28 %9 +OpCooperativeMatrixStoreKHR %44 %43 %11 %8 OpLine %3 1 1 -%44 = OpLoad %13 %32 -OpLine %3 11 19 +%45 = OpLoad %13 %32 OpLine %3 11 5 -%45 = OpAccessChain %36 %28 %9 -OpCooperativeMatrixStoreKHR %45 %44 %11 %8 -OpLine %3 1 1 -%46 = OpLoad %13 %32 -OpLine %3 12 5 -OpStore %30 %46 +OpStore %29 %45 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl index af7a6195b6..b4a6947762 100644 --- a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl @@ -5,18 +5,17 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { - var c: coop_mat8x8 = coop_mat8x8(); + var c: coop_mat8x8; var d: coop_mat8x8; - let _e2 = c; - coopLoad(_e2, (&ext[4]), 8u); - let _e7 = a; - let _e9 = b; - let _e10 = c; - d = coopMultiplyAdd(_e7, _e9, _e10); - let _e13 
= d; - coopStore(_e13, (&ext[0]), 8u); - let _e17 = d; - c = _e17; + c = coopLoad((&ext[4]), 8u); + let _e6 = a; + let _e8 = b; + let _e9 = c; + d = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d; + coopStore(_e12, (&ext[0]), 8u); + let _e16 = d; + c = _e16; return; } From 4fa9f0022dfd18f79fb8776f1d6f99d62d82673d Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 28 Sep 2025 01:11:50 -0700 Subject: [PATCH 12/12] coop: support generic argument on coopLoad --- naga/src/back/wgsl/writer.rs | 24 +++++++++++++++++-- naga/src/front/wgsl/error.rs | 6 +++++ naga/src/front/wgsl/lower/mod.rs | 24 +++++++++++++++---- naga/src/front/wgsl/parse/ast.rs | 1 + naga/src/front/wgsl/parse/mod.rs | 8 ++++++- naga/tests/in/wgsl/cooperative-matrix.wgsl | 2 +- .../out/spv/wgsl-cooperative-matrix.spvasm | 4 ++-- .../out/wgsl/wgsl-cooperative-matrix.wgsl | 2 +- 8 files changed, 59 insertions(+), 12 deletions(-) diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 9a116d7464..4e81a2f7cd 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1712,9 +1712,29 @@ impl Writer { | Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} - Expression::CooperativeLoad { ref data, .. 
} => { + Expression::CooperativeLoad { + columns, + rows, + role, + ref data, + } => { let suffix = if data.row_major { "T" } else { "" }; - write!(self.out, "coopLoad{suffix}(")?; + let scalar = func_ctx.info[data.pointer] + .ty + .inner_with(&module.types) + .pointer_base_type() + .unwrap() + .inner_with(&module.types) + .scalar() + .unwrap(); + write!( + self.out, + "coopLoad{suffix}>(", + columns as u32, + rows as u32, + scalar.try_to_wgsl().unwrap(), + role, + )?; self.write_expr(module, data.pointer, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, data.stride, func_ctx)?; diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index f0d6a4b848..b4c5c9b99c 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -413,6 +413,7 @@ pub(crate) enum Error<'a> { span: Span, }, UnderspecifiedCooperativeMatrix, + InvalidCooperativeLoadType(Span), UnsupportedCooperativeScalar(Span), } @@ -1393,6 +1394,11 @@ impl<'a> Error<'a> { labels: vec![], notes: vec![format!("must be F32")], }, + Error::InvalidCooperativeLoadType(span) => ParseError { + message: "cooperative load should have a generic type for coop_mat".into(), + labels: vec![(span, "type needs the coop_mat<...>".into())], + notes: vec![format!("must be a valid cooperative type")], + }, Error::UnsupportedCooperativeScalar(span) => ParseError { message: "cooperative scalar type is not supported".into(), labels: vec![(span, "type needs the scalar type specified".into())], diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index fd005a3264..326455c8ff 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1905,6 +1905,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { stmt.span, function, arguments, + None, &mut ctx.as_expression(block, &mut emitter), true, )?; @@ -2227,9 +2228,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Expression::Call { ref function, ref arguments, + result_ty, } => { 
let handle = self - .call(span, function, arguments, ctx, false)? + .call(span, function, arguments, result_ty, ctx, false)? .ok_or(Error::FunctionReturnsVoid(function.span))?; return Ok(Typed::Plain(handle)); } @@ -2424,6 +2426,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span: Span, function: &ast::Ident<'source>, arguments: &[Handle>], + result_ty: Option<(Handle>, Span)>, ctx: &mut ExpressionContext<'source, '_, '_>, is_statement: bool, ) -> Result<'source, Option>> { @@ -3145,9 +3148,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let row_major = function.name.ends_with("T"); let mut args = ctx.prepare_args(arguments, 1, span); let pointer = self.expression(args.next()?, ctx)?; - //TODO: read from generic argument - let columns = crate::CooperativeSize::Eight; - let rows = crate::CooperativeSize::Eight; + let (matrix_ty, matrix_span) = result_ty.expect("generic argument"); + let (columns, rows, role) = match ctx.types[matrix_ty] { + ast::Type::CooperativeMatrix { + columns, + rows, + role, + .. + } => (columns, rows, role), + _ => { + return Err(Box::new(Error::InvalidCooperativeLoadType( + matrix_span, + ))) + } + }; let stride = if args.total_args > 1 { self.expression(args.next()?, ctx)? 
} else { @@ -3167,7 +3181,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { crate::Expression::CooperativeLoad { columns, rows, - role: crate::CooperativeRole::C, //TODO + role, data: crate::CooperativeData { pointer, stride, diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index af05a84110..da093a2f06 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -487,6 +487,7 @@ pub enum Expression<'a> { Call { function: Ident<'a>, arguments: Vec>>, + result_ty: Option<(Handle>, Span)>, }, Index { base: Handle>, diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 576bd9c977..50275c5078 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -800,6 +800,11 @@ impl Parser { } // everything else must be handled later, since they can be hidden by user-defined functions. _ => { + let result_ty = if lexer.peek().0 == Token::Paren('<') { + Some(self.singular_generic(lexer, ctx)?) 
+ } else { + None + }; let arguments = self.arguments(lexer, ctx)?; ctx.unresolved.insert(ast::Dependency { ident: name, @@ -811,6 +816,7 @@ impl Parser { span: name_span, }, arguments, + result_ty, } } }; @@ -959,7 +965,7 @@ impl Parser { } else if let Token::Paren('(') = lexer.peek().0 { self.pop_rule_span(lexer); return self.function_call(lexer, word, span, ctx); - } else if word == "bitcast" { + } else if ["bitcast", "coopLoad"].contains(&word) { self.pop_rule_span(lexer); return self.function_call(lexer, word, span, ctx); } else { diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 641af45b7f..06afb5cdf4 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -5,7 +5,7 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { - var c = coopLoad(&ext[4]); + var c = coopLoad>(&ext[4]); var d = coopMultiplyAdd(a, b, c); coopStore(d, &ext[0]); c = d; diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index 4f443448d7..f2c9b5ceb5 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -19,7 +19,7 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { - var c = coopLoad(&ext[4]); + var c = coopLoad>(&ext[4]); var d = coopMultiplyAdd(a, b, c); coopStore(d, &ext[0]); c = d; @@ -71,7 +71,7 @@ OpMemberDecorate %22 0 Offset 0 %28 = OpAccessChain %27 %21 %9 OpBranch %34 %34 = OpLabel -OpLine %3 8 23 +OpLine %3 8 44 OpLine %3 8 13 %37 = OpAccessChain %35 %28 %36 %38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8 diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl index b4a6947762..183dc84ad7 100644 --- a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl @@ -8,7 +8,7 @@ fn main() { var c: coop_mat8x8; var d: 
coop_mat8x8; - c = coopLoad((&ext[4]), 8u); + c = coopLoad>((&ext[4]), 8u); let _e6 = a; let _e8 = b; let _e9 = c;