diff --git a/CHANGELOG.md b/CHANGELOG.md index 76e1ec02882..fb212891a94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,6 +85,10 @@ SamplerDescriptor { - Using both the wgpu command encoding APIs and `CommandEncoder::as_hal_mut` on the same encoder will now result in a panic. - Allow `include_spirv!` and `include_spirv_raw!` macros to be used in constants and statics. By @clarfonthey in [#8250](https://github.com/gfx-rs/wgpu/pull/8250). +#### Naga + +- Prevent UB with invalid ray query calls on spirv. By @Vecvec in [#8390](https://github.com/gfx-rs/wgpu/pull/8390). + ### Bug Fixes #### naga diff --git a/naga-test/src/lib.rs b/naga-test/src/lib.rs index 078aa27c405..769026edf5e 100644 --- a/naga-test/src/lib.rs +++ b/naga-test/src/lib.rs @@ -114,6 +114,7 @@ pub struct SpirvOutParameters { pub separate_entry_points: bool, #[serde(deserialize_with = "deserialize_binding_map")] pub binding_map: naga::back::spv::BindingMap, + pub ray_query_initialization_tracking: bool, pub use_storage_input_output_16: bool, } impl Default for SpirvOutParameters { @@ -126,6 +127,7 @@ impl Default for SpirvOutParameters { force_point_size: false, clamp_frag_depth: false, separate_entry_points: false, + ray_query_initialization_tracking: true, use_storage_input_output_16: true, binding_map: naga::back::spv::BindingMap::default(), } @@ -159,6 +161,7 @@ impl SpirvOutParameters { binding_map: self.binding_map.clone(), zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill, force_loop_bounding: true, + ray_query_initialization_tracking: self.ray_query_initialization_tracking, debug_info, use_storage_input_output_16: self.use_storage_input_output_16, } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 7758d86c414..95650e4c536 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -203,7 +203,7 @@ impl Writer { )); let clamp_id = self.id_gen.next(); - body.push(Instruction::ext_inst( + body.push(Instruction::ext_inst_gl_op( self.gl450_ext_inst_id, 
spirv::GLOp::FClamp, float_type_id, @@ -1026,7 +1026,7 @@ impl BlockContext<'_> { }; let max_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, max_op, result_type_id, @@ -1034,7 +1034,7 @@ impl BlockContext<'_> { &[arg0_id, arg1_id], )); - MathOp::Custom(Instruction::ext_inst( + MathOp::Custom(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, min_op, result_type_id, @@ -1068,7 +1068,7 @@ impl BlockContext<'_> { arg2_id = self.writer.get_constant_composite(ty, &self.temp_list); } - MathOp::Custom(Instruction::ext_inst( + MathOp::Custom(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::FClamp, result_type_id, @@ -1282,7 +1282,7 @@ impl BlockContext<'_> { &self.temp_list, )); - MathOp::Custom(Instruction::ext_inst( + MathOp::Custom(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::FMix, result_type_id, @@ -1339,7 +1339,7 @@ impl BlockContext<'_> { }; let lsb_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::FindILsb, result_type_id, @@ -1347,7 +1347,7 @@ impl BlockContext<'_> { &[arg0_id], )); - MathOp::Custom(Instruction::ext_inst( + MathOp::Custom(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, result_type_id, @@ -1388,7 +1388,7 @@ impl BlockContext<'_> { }; let msb_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, if width != 4 { spirv::GLOp::FindILsb @@ -1445,7 +1445,7 @@ impl BlockContext<'_> { // o = min(offset, w) let offset_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, u32_type, @@ -1465,7 +1465,7 @@ impl BlockContext<'_> { // c = min(count, tmp) let count_id = self.gen_id(); - 
block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, u32_type, @@ -1495,7 +1495,7 @@ impl BlockContext<'_> { // o = min(offset, w) let offset_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, u32_type, @@ -1515,7 +1515,7 @@ impl BlockContext<'_> { // c = min(count, tmp) let count_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, u32_type, @@ -1610,7 +1610,7 @@ impl BlockContext<'_> { }; block.body.push(match math_op { - MathOp::Ext(op) => Instruction::ext_inst( + MathOp::Ext(op) => Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, op, result_type_id, @@ -1621,7 +1621,13 @@ impl BlockContext<'_> { }); id } - crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id, + crate::Expression::LocalVariable(variable) => { + if let Some(rq_tracker) = self.function.ray_query_tracker_variables.get(&variable) { + self.ray_query_tracker_expr + .insert(expr_handle, rq_tracker.id); + } + self.function.variables[&variable].id + } crate::Expression::Load { pointer } => { self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)? 
} @@ -1772,6 +1778,10 @@ impl BlockContext<'_> { crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, crate::Expression::RayQueryGetIntersection { query, committed } => { let query_id = self.cached[query]; + let init_tracker_id = *self + .ray_query_tracker_expr + .get(&query) + .expect("not a cached ray query"); let func_id = self .writer .write_ray_query_get_intersection_function(committed, self.ir_module); @@ -1782,7 +1792,7 @@ impl BlockContext<'_> { intersection_type_id, id, func_id, - &[query_id], + &[query_id, init_tracker_id], )); id } @@ -2008,7 +2018,7 @@ impl BlockContext<'_> { let max_const_id = maybe_splat_const(self.writer, max_const_id); let clamp_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::FClamp, expr_type_id, @@ -2671,7 +2681,7 @@ impl BlockContext<'_> { }); let clamp_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, clamp_op, wide_vector_type_id, @@ -2765,7 +2775,7 @@ impl BlockContext<'_> { let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit)); let clamp_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, clamp_op, result_type_id, diff --git a/naga/src/back/spv/image.rs b/naga/src/back/spv/image.rs index 3aec1333f0c..78d7c79edfb 100644 --- a/naga/src/back/spv/image.rs +++ b/naga/src/back/spv/image.rs @@ -446,7 +446,7 @@ impl BlockContext<'_> { // and negative values in a single instruction: negative values of // `input_id` get treated as very large positive values. 
let restricted_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, type_id, @@ -580,7 +580,7 @@ impl BlockContext<'_> { // and negative values in a single instruction: negative values of // `coordinates` get treated as very large positive values. let restricted_coordinates_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, coordinates.type_id, @@ -923,7 +923,7 @@ impl BlockContext<'_> { // Clamp the coords to the calculated margins let clamped_coords_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::NClamp, vec2f_type_id, diff --git a/naga/src/back/spv/index.rs b/naga/src/back/spv/index.rs index 3a15ee88060..85caa9457f0 100644 --- a/naga/src/back/spv/index.rs +++ b/naga/src/back/spv/index.rs @@ -366,7 +366,7 @@ impl BlockContext<'_> { // One or the other of the index or length is dynamic, so emit code for // BoundsCheckPolicy::Restrict. 
let restricted_index_id = self.gen_id(); - block.body.push(Instruction::ext_inst( + block.body.push(Instruction::ext_inst_gl_op( self.writer.gl450_ext_inst_id, spirv::GLOp::UMin, self.writer.get_u32_type_id(), diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 788c3bc119a..1f6e748a89c 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -156,18 +156,28 @@ impl super::Instruction { instruction } - pub(super) fn ext_inst( + pub(super) fn ext_inst_gl_op( set_id: Word, op: spirv::GLOp, result_type_id: Word, id: Word, operands: &[Word], + ) -> Self { + Self::ext_inst(set_id, op as u32, result_type_id, id, operands) + } + + pub(super) fn ext_inst( + set_id: Word, + op: u32, + result_type_id: Word, + id: Word, + operands: &[Word], ) -> Self { let mut instruction = Self::new(Op::ExtInst); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(set_id); - instruction.add_operand(op as u32); + instruction.add_operand(op); for operand in operands { instruction.add_operand(*operand) } diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 4690dc71951..3d98042e138 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -151,6 +151,8 @@ struct Function { signature: Option, parameters: Vec, variables: crate::FastHashMap, LocalVariable>, + /// Map from a local variable that is a ray query to its u32 tracker. + ray_query_tracker_variables: crate::FastHashMap, LocalVariable>, /// List of local variables used as a counters to ensure that all loops are bounded. 
force_loop_bounding_vars: Vec, @@ -445,6 +447,16 @@ struct LookupFunctionType { return_type_id: Word, } +#[derive(Debug, PartialEq, Clone, Hash, Eq)] +enum LookupRayQueryFunction { + Initialize, + Proceed, + GenerateIntersection, + ConfirmIntersection, + GetVertexPositions { committed: bool }, + GetIntersection { committed: bool }, +} + #[derive(Debug)] enum Dimension { Scalar, @@ -685,6 +697,10 @@ struct BlockContext<'w> { expression_constness: ExpressionConstnessTracker, force_loop_bounding: bool, + + /// Hash from an expression whose type is a ray query / pointer to a ray query to its tracker. + /// Note: this is sparse, so can't be a handle vec + ray_query_tracker_expr: crate::FastHashMap, Word>, } impl BlockContext<'_> { @@ -741,6 +757,7 @@ pub struct Writer { /// The set of spirv extensions used. extensions_used: crate::FastIndexSet<&'static str>, + debug_strings: Vec, debugs: Vec, annotations: Vec, flags: WriterFlags, @@ -773,12 +790,15 @@ pub struct Writer { // Just a temporary list of SPIR-V ids temp_list: Vec, - ray_get_committed_intersection_function: Option, - ray_get_candidate_intersection_function: Option, + ray_query_functions: crate::FastHashMap, /// F16 I/O polyfill manager for handling `f16` input/output variables /// when `StorageInputOutput16` capability is not available. io_f16_polyfills: f16_polyfill::F16IoPolyfill, + + /// Non semantic debug printf extension `OpExtInstImport` + debug_printf: Option, + pub(crate) ray_query_initialization_tracking: bool, } bitflags::bitflags! { @@ -810,6 +830,26 @@ bitflags::bitflags! { /// /// [`BuiltIn::FragDepth`]: crate::BuiltIn::FragDepth const CLAMP_FRAG_DEPTH = 0x10; + + /// Instead of silently failing if the arguments to generate a ray query are + /// invalid, uses debug printf extension to print to the command line + /// + /// Note: VK_KHR_shader_non_semantic_info must be enabled. This will have no + /// effect if `options.ray_query_initialization_tracking` is set to false. 
+ const PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL = 0x20; + } +} + +bitflags::bitflags! { + /// How far through a ray query are we + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub(super) struct RayQueryPoint: u32 { + /// Ray query has been successfully initialized. + const INITIALIZED = 1 << 0; + /// Proceed has been called on ray query. + const PROCEED = 1 << 1; + /// Proceed has returned false (have finished traversal). + const FINISHED_TRAVERSAL = 1 << 2; } } @@ -867,6 +907,10 @@ pub struct Options<'a> { /// to think the number of iterations is bounded. pub force_loop_bounding: bool, + /// if set, ray queries will get a variable to track their state to prevent + /// misuse. + pub ray_query_initialization_tracking: bool, + /// Whether to use the `StorageInputOutput16` capability for `f16` shader I/O. /// When false, `f16` I/O is polyfilled using `f32` types with conversions. pub use_storage_input_output_16: bool, @@ -891,6 +935,7 @@ impl Default for Options<'_> { bounds_check_policies: BoundsCheckPolicies::default(), zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill, force_loop_bounding: true, + ray_query_initialization_tracking: true, use_storage_input_output_16: true, debug_info: None, } diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index 05a55c78d83..c486c842872 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -2,13 +2,46 @@ Generating SPIR-V for ray query operations. */ -use alloc::vec; +use alloc::{vec, vec::Vec}; use super::{ Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType, Writer, }; -use crate::arena::Handle; +use crate::{arena::Handle, back::spv::LookupRayQueryFunction}; + +/// helper function to check if a particular flag is set in a u32. 
+fn write_ray_flags_contains_flags( + writer: &mut Writer, + block: &mut Block, + id: spirv::Word, + flag: u32, +) -> spirv::Word { + let bit_id = writer.get_constant_scalar(crate::Literal::U32(flag)); + let zero_id = writer.get_constant_scalar(crate::Literal::U32(0)); + let u32_type_id = writer.get_u32_type_id(); + let bool_ty = writer.get_bool_type_id(); + + let and_id = writer.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::BitwiseAnd, + u32_type_id, + and_id, + id, + bit_id, + )); + + let eq_id = writer.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::INotEqual, + bool_ty, + eq_id, + and_id, + zero_id, + )); + + eq_id +} impl Writer { pub(super) fn write_ray_query_get_intersection_function( @@ -16,13 +49,14 @@ impl Writer { is_committed: bool, ir_module: &crate::Module, ) -> spirv::Word { - if is_committed { - if let Some(func_id) = self.ray_get_committed_intersection_function { - return func_id; - } - } else if let Some(func_id) = self.ray_get_candidate_intersection_function { - return func_id; - }; + if let Some(&word) = + self.ray_query_functions + .get(&LookupRayQueryFunction::GetIntersection { + committed: is_committed, + }) + { + return word; + } let ray_intersection = ir_module.special_types.ray_intersection.unwrap(); let intersection_type_id = self.get_handle_type_id(ray_intersection); let intersection_pointer_type_id = @@ -57,7 +91,7 @@ impl Writer { let argument_type_id = self.get_ray_query_pointer_id(); let func_ty = self.get_function_type(LookupFunctionType { - parameter_type_ids: vec![argument_type_id], + parameter_type_ids: vec![argument_type_id, flag_pointer_type_id], return_type_id: intersection_type_id, }); @@ -77,10 +111,19 @@ impl Writer { handle_id: 0, }); + let intersection_tracker_id = self.id_gen.next(); + let instruction = + Instruction::function_parameter(flag_pointer_type_id, intersection_tracker_id); + function.parameters.push(FunctionArgument { + instruction, + handle_id: 1, + }); + let label_id = 
self.id_gen.next(); let mut block = Block::new(label_id); let blank_intersection_id = self.id_gen.next(); + // This must be before everything else in the function. block.body.push(Instruction::variable( intersection_pointer_type_id, blank_intersection_id, @@ -93,14 +136,73 @@ impl Writer { } else { spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR } as _)); - let raw_kind_id = self.id_gen.next(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTypeKHR, + + let loaded_ray_query_tracker_id = self.id_gen.next(); + block.body.push(Instruction::load( flag_type_id, - raw_kind_id, - query_id, - intersection_id, + loaded_ray_query_tracker_id, + intersection_tracker_id, + None, + )); + let proceeded_id = write_ray_flags_contains_flags( + self, + &mut block, + loaded_ray_query_tracker_id, + super::RayQueryPoint::PROCEED.bits(), + ); + let finished_proceed_id = write_ray_flags_contains_flags( + self, + &mut block, + loaded_ray_query_tracker_id, + super::RayQueryPoint::FINISHED_TRAVERSAL.bits(), + ); + let proceed_finished_correct_id = if is_committed { + finished_proceed_id + } else { + let not_finished_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::LogicalNot, + bool_type_id, + not_finished_id, + finished_proceed_id, + )); + not_finished_id + }; + + let is_valid_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_type_id, + is_valid_id, + proceed_finished_correct_id, + proceeded_id, + )); + + let valid_id = self.id_gen.next(); + let mut valid_block = Block::new(valid_id); + + let final_label_id = self.id_gen.next(); + let mut final_block = Block::new(final_label_id); + + block.body.push(Instruction::selection_merge( + final_label_id, + spirv::SelectionControl::NONE, )); + function.consume( + block, + Instruction::branch_conditional(is_valid_id, valid_id, final_label_id), + ); + + let raw_kind_id = self.id_gen.next(); + valid_block + .body + 
.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + raw_kind_id, + query_id, + intersection_id, + )); let kind_id = if is_committed { // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType` raw_kind_id @@ -111,7 +213,7 @@ impl Writer { spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _, )); - block.body.push(Instruction::binary( + valid_block.body.push(Instruction::binary( spirv::Op::IEqual, self.get_bool_type_id(), condition_id, @@ -119,7 +221,7 @@ impl Writer { committed_triangle_kind_id, )); let kind_id = self.id_gen.next(); - block.body.push(Instruction::select( + valid_block.body.push(Instruction::select( flag_type_id, kind_id, condition_id, @@ -134,20 +236,20 @@ impl Writer { }; let idx_id = self.get_index_constant(0); let access_idx = self.id_gen.next(); - block.body.push(Instruction::access_chain( + valid_block.body.push(Instruction::access_chain( flag_pointer_type_id, access_idx, blank_intersection_id, &[idx_id], )); - block + valid_block .body .push(Instruction::store(access_idx, kind_id, None)); let not_none_comp_id = self.id_gen.next(); let none_id = self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _)); - block.body.push(Instruction::binary( + valid_block.body.push(Instruction::binary( spirv::Op::INotEqual, self.get_bool_type_id(), not_none_comp_id, @@ -158,16 +260,20 @@ impl Writer { let not_none_label_id = self.id_gen.next(); let mut not_none_block = Block::new(not_none_label_id); - let final_label_id = self.id_gen.next(); - let mut final_block = Block::new(final_label_id); + let outer_merge_label_id = self.id_gen.next(); + let outer_merge_block = Block::new(outer_merge_label_id); - block.body.push(Instruction::selection_merge( - final_label_id, + valid_block.body.push(Instruction::selection_merge( + outer_merge_label_id, spirv::SelectionControl::NONE, )); function.consume( - block, - 
Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id), + valid_block, + Instruction::branch_conditional( + not_none_comp_id, + not_none_label_id, + outer_merge_label_id, + ), ); let instance_custom_index_id = self.id_gen.next(); @@ -426,7 +532,8 @@ impl Writer { .body .push(Instruction::store(access_idx, front_face_id, None)); function.consume(tri_block, Instruction::branch(merge_label_id)); - function.consume(merge_block, Instruction::branch(final_label_id)); + function.consume(merge_block, Instruction::branch(outer_merge_label_id)); + function.consume(outer_merge_block, Instruction::branch(final_label_id)); let loaded_blank_intersection_id = self.id_gen.next(); final_block.body.push(Instruction::load( @@ -441,151 +548,1312 @@ impl Writer { ); function.to_words(&mut self.logical_layout.function_definitions); - if is_committed { - self.ray_get_committed_intersection_function = Some(func_id); - } else { - self.ray_get_candidate_intersection_function = Some(func_id); - } + self.ray_query_functions.insert( + LookupRayQueryFunction::GetIntersection { + committed: is_committed, + }, + func_id, + ); func_id } -} -impl BlockContext<'_> { - pub(super) fn write_ray_query_function( - &mut self, - query: Handle, - function: &crate::RayQueryFunction, - block: &mut Block, - ) { - let query_id = self.cached[query]; - match *function { - crate::RayQueryFunction::Initialize { - acceleration_structure, - descriptor, - } => { - //Note: composite extract indices and types must match `generate_ray_desc_type` - let desc_id = self.cached[descriptor]; - let acc_struct_id = self.get_handle_id(acceleration_structure); + fn write_ray_query_initialize(&mut self, ir_module: &crate::Module) -> spirv::Word { + if let Some(&word) = self + .ray_query_functions + .get(&LookupRayQueryFunction::Initialize) + { + return word; + } - let flag_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); - let ray_flags_id = self.gen_id(); - 
block.body.push(Instruction::composite_extract( - flag_type_id, - ray_flags_id, - desc_id, - &[0], - )); - let cull_mask_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - flag_type_id, - cull_mask_id, - desc_id, - &[1], - )); + let ray_query_type_id = self.get_ray_query_pointer_id(); + let acceleration_structure_type_id = + self.get_localtype_id(super::LocalType::AccelerationStructure); + let ray_desc_type_id = self.get_handle_type_id( + ir_module + .special_types + .ray_desc + .expect("ray desc should be set if ray queries are being initialized"), + ); - let scalar_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32)); - let tmin_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - scalar_type_id, - tmin_id, - desc_id, - &[2], - )); - let tmax_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - scalar_type_id, - tmax_id, - desc_id, - &[3], - )); + let u32_ty = self.get_u32_type_id(); + let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function); - let vector_type_id = self.get_numeric_type_id(NumericType::Vector { - size: crate::VectorSize::Tri, - scalar: crate::Scalar::F32, - }); - let ray_origin_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - vector_type_id, - ray_origin_id, - desc_id, - &[4], - )); - let ray_dir_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - vector_type_id, - ray_dir_id, - desc_id, - &[5], - )); + let bool_type_id = self.get_bool_type_id(); + let bool_vec3_type_id = self.get_vec3_bool_type_id(); - block.body.push(Instruction::ray_query_initialize( - query_id, - acc_struct_id, - ray_flags_id, - cull_mask_id, - ray_origin_id, - tmin_id, - ray_dir_id, - tmax_id, + let func_ty = self.get_function_type(LookupFunctionType { + parameter_type_ids: vec![ + ray_query_type_id, + acceleration_structure_type_id, + ray_desc_type_id, + u32_ptr_ty, + ], + return_type_id: self.void_type, + }); + + let mut 
function = Function::default(); + let func_id = self.id_gen.next(); + function.signature = Some(Instruction::function( + self.void_type, + func_id, + spirv::FunctionControl::empty(), + func_ty, + )); + + let query_id = self.id_gen.next(); + let instruction = Instruction::function_parameter(ray_query_type_id, query_id); + function.parameters.push(FunctionArgument { + instruction, + handle_id: 0, + }); + + let acceleration_structure_id = self.id_gen.next(); + let instruction = Instruction::function_parameter( + acceleration_structure_type_id, + acceleration_structure_id, + ); + function.parameters.push(FunctionArgument { + instruction, + handle_id: 1, + }); + + let desc_id = self.id_gen.next(); + let instruction = Instruction::function_parameter(ray_desc_type_id, desc_id); + function.parameters.push(FunctionArgument { + instruction, + handle_id: 2, + }); + + let init_tracker_id = self.id_gen.next(); + let instruction = Instruction::function_parameter(u32_ptr_ty, init_tracker_id); + function.parameters.push(FunctionArgument { + instruction, + handle_id: 3, + }); + + let label_id = self.id_gen.next(); + let mut block = Block::new(label_id); + + let flag_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); + + //Note: composite extract indices and types must match `generate_ray_desc_type` + let ray_flags_id = self.id_gen.next(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.id_gen.next(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32)); + let tmin_id = self.id_gen.next(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.id_gen.next(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + 
+ let vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }); + let ray_origin_id = self.id_gen.next(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.id_gen.next(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + let valid_id = self.ray_query_initialization_tracking.then(||{ + let tmin_le_tmax_id = self.id_gen.next(); + // Because this checks if tmin and tmax are ordered too (i.e: not NaN), there is no need for an additional check. + block.body.push(Instruction::binary( + spirv::Op::FOrdLessThanEqual, + bool_type_id, + tmin_le_tmax_id, + tmin_id, + tmax_id, + )); + + let tmin_ge_zero_id = self.id_gen.next(); + let zero_id = self.get_constant_scalar(crate::Literal::F32(0.0)); + block.body.push(Instruction::binary( + spirv::Op::FOrdGreaterThanEqual, + bool_type_id, + tmin_ge_zero_id, + tmin_id, + zero_id, + )); + + let ray_origin_infinite_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::IsInf, + bool_vec3_type_id, + ray_origin_infinite_id, + ray_origin_id, + )); + let any_ray_origin_infinite_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::Any, + bool_type_id, + any_ray_origin_infinite_id, + ray_origin_infinite_id, + )); + + let ray_origin_nan_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::IsNan, + bool_vec3_type_id, + ray_origin_nan_id, + ray_origin_id, + )); + let any_ray_origin_nan_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::Any, + bool_type_id, + any_ray_origin_nan_id, + ray_origin_nan_id, + )); + + let ray_origin_not_finite_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalOr, + bool_type_id, + ray_origin_not_finite_id, + any_ray_origin_nan_id, + any_ray_origin_infinite_id, + )); + + let 
all_ray_origin_finite_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::LogicalNot, + bool_type_id, + all_ray_origin_finite_id, + ray_origin_not_finite_id, + )); + + let ray_dir_infinite_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::IsInf, + bool_vec3_type_id, + ray_dir_infinite_id, + ray_dir_id, + )); + let any_ray_dir_infinite_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::Any, + bool_type_id, + any_ray_dir_infinite_id, + ray_dir_infinite_id, + )); + + let ray_dir_nan_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::IsNan, + bool_vec3_type_id, + ray_dir_nan_id, + ray_dir_id, + )); + let any_ray_dir_nan_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::Any, + bool_type_id, + any_ray_dir_nan_id, + ray_dir_nan_id, + )); + + let ray_dir_not_finite_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalOr, + bool_type_id, + ray_dir_not_finite_id, + any_ray_dir_nan_id, + any_ray_dir_infinite_id, + )); + + let all_ray_dir_finite_id = self.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::LogicalNot, + bool_type_id, + all_ray_dir_finite_id, + ray_dir_not_finite_id, + )); + + /// Writes spirv to check that less than two booleans are true + /// + /// For each boolean: removes it, `and`s it with all others (i.e for all possible combinations of two booleans in the list checks to see if both are true). + /// Then `or`s all of these checks together. This produces whether two or more booleans are true. 
+ fn write_less_than_2_true( + writer: &mut Writer, + block: &mut Block, + mut bools: Vec, + ) -> spirv::Word { + assert!(bools.len() > 1, "Must have multiple booleans!"); + let bool_ty = writer.get_bool_type_id(); + let mut each_two_true = Vec::new(); + while let Some(last_bool) = bools.pop() { + for &bool in &bools { + let both_true_id = writer.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_ty, + both_true_id, + last_bool, + bool, + )); + each_two_true.push(both_true_id); + } + } + let mut all_or_id = each_two_true.pop().expect("since this must have multiple booleans, there must be at least one thing in `each_two_true"); + for two_true in each_two_true { + let new_all_or_id = writer.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalOr, + bool_ty, + new_all_or_id, + all_or_id, + two_true, + )); + all_or_id = new_all_or_id; + } + + let less_than_two_id = writer.id_gen.next(); + block.body.push(Instruction::unary( + spirv::Op::LogicalNot, + bool_ty, + less_than_two_id, + all_or_id, )); + less_than_two_id } - crate::RayQueryFunction::Proceed { result } => { - let id = self.gen_id(); - self.cached[result] = id; - let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); - block - .body - .push(Instruction::ray_query_proceed(result_type_id, id, query_id)); - } - crate::RayQueryFunction::GenerateIntersection { hit_t } => { - let hit_id = self.cached[hit_t]; - block - .body - .push(Instruction::ray_query_generate_intersection( - query_id, hit_id, - )); + let contains_skip_triangles = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::SKIP_TRIANGLES.bits(), + ); + let contains_skip_aabbs = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::SKIP_AABBS.bits(), + ); + + let not_contain_skip_triangles_aabbs = write_less_than_2_true( + self, + &mut block, + vec![contains_skip_triangles, contains_skip_aabbs], + ); + + let 
contains_cull_back = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::CULL_BACK_FACING.bits(), + ); + let contains_cull_front = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::CULL_FRONT_FACING.bits(), + ); + + let not_contain_skip_triangles_cull = write_less_than_2_true( + self, + &mut block, + vec![ + contains_skip_triangles, + contains_cull_back, + contains_cull_front, + ], + ); + + let contains_opaque = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::FORCE_OPAQUE.bits(), + ); + let contains_no_opaque = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::FORCE_NO_OPAQUE.bits(), + ); + let contains_cull_opaque = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::CULL_OPAQUE.bits(), + ); + let contains_cull_no_opaque = write_ray_flags_contains_flags( + self, + &mut block, + ray_flags_id, + crate::RayFlag::CULL_NO_OPAQUE.bits(), + ); + + let not_contain_multiple_opaque = write_less_than_2_true( + self, + &mut block, + vec![ + contains_opaque, + contains_no_opaque, + contains_cull_opaque, + contains_cull_no_opaque, + ], + ); + + let tmin_tmax_valid_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_type_id, + tmin_tmax_valid_id, + tmin_le_tmax_id, + tmin_ge_zero_id, + )); + + let origin_dir_valid_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_type_id, + origin_dir_valid_id, + all_ray_origin_finite_id, + all_ray_dir_finite_id, + )); + + let flags_skip_tri_aabbs_tri_cull_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_type_id, + flags_skip_tri_aabbs_tri_cull_id, + not_contain_skip_triangles_aabbs, + not_contain_skip_triangles_cull, + )); + let flags_valid_id = self.id_gen.next(); + block.body.push(Instruction::binary( + 
spirv::Op::LogicalAnd, + bool_type_id, + flags_valid_id, + flags_skip_tri_aabbs_tri_cull_id, + not_contain_multiple_opaque, + )); + + let tmin_tmax_origin_dir_valid_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_type_id, + tmin_tmax_origin_dir_valid_id, + tmin_tmax_valid_id, + origin_dir_valid_id, + )); + + let all_valid_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + bool_type_id, + all_valid_id, + tmin_tmax_origin_dir_valid_id, + flags_valid_id, + )); + + all_valid_id + }); + + let merge_label_id = self.id_gen.next(); + let merge_block = Block::new(merge_label_id); + + // NOTE: this block will be unreachable if initialization tracking is set to false. + let invalid_label_id = self.id_gen.next(); + let mut invalid_block = Block::new(invalid_label_id); + + let valid_label_id = self.id_gen.next(); + let mut valid_block = Block::new(valid_label_id); + + match valid_id { + Some(all_valid_id) => { + block.body.push(Instruction::selection_merge( + merge_label_id, + spirv::SelectionControl::NONE, + )); + function.consume( + block, + Instruction::branch_conditional(all_valid_id, valid_label_id, invalid_label_id), + ); } - crate::RayQueryFunction::ConfirmIntersection => { - block - .body - .push(Instruction::ray_query_confirm_intersection(query_id)); + None => { + function.consume(block, Instruction::branch(valid_label_id)); } - crate::RayQueryFunction::Terminate => {} } - } - pub(super) fn write_ray_query_return_vertex_position( - &mut self, - query: Handle, - block: &mut Block, - is_committed: bool, - ) -> spirv::Word { - let query_id = self.cached[query]; - let id = self.gen_id(); - let ray_vertex_return_ty = self - .ir_module - .special_types - .ray_vertex_return - .expect("type should have been populated"); - let ray_vertex_return_ty_id = self.writer.get_handle_type_id(ray_vertex_return_ty); - let intersection_id = - self.writer - .get_constant_scalar(crate::Literal::U32(if 
is_committed {
-                    spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
-                } else {
-                    spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
-                } as _));
-        block
+        valid_block.body.push(Instruction::ray_query_initialize(
+            query_id,
+            acceleration_structure_id,
+            ray_flags_id,
+            cull_mask_id,
+            ray_origin_id,
+            tmin_id,
+            ray_dir_id,
+            tmax_id,
+        ));
+
+        let const_initialized = self.get_constant_scalar(crate::Literal::U32(
+            super::RayQueryPoint::INITIALIZED.bits(),
+        ));
+        valid_block
             .body
-            .push(Instruction::ray_query_return_vertex_position(
-                ray_vertex_return_ty_id,
-                id,
-                query_id,
-                intersection_id,
-            ));
-        id
+            .push(Instruction::store(init_tracker_id, const_initialized, None));
+
+        function.consume(valid_block, Instruction::branch(merge_label_id));
+
+        if self
+            .flags
+            .contains(super::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL)
+        {
+            self.write_debug_printf(
+                &mut invalid_block,
+                "Naga ignored invalid arguments to rayQueryInitialize with flags: %u t_min: %f t_max: %f origin: %v4f dir: %v4f",
+                &[
+                    ray_flags_id,
+                    tmin_id,
+                    tmax_id,
+                    ray_origin_id,
+                    ray_dir_id,
+                ],
+            );
+        }
+
+        function.consume(invalid_block, Instruction::branch(merge_label_id));
+
+        function.consume(merge_block, Instruction::return_void());
+
+        function.to_words(&mut self.logical_layout.function_definitions);
+
+        self.ray_query_functions
+            .insert(LookupRayQueryFunction::Initialize, func_id);
+        func_id
+    }
+
+    /// Returns the id of the SPIR-V helper function that guards a ray query "proceed".
+    ///
+    /// The helper is written once and cached in `ray_query_functions` under
+    /// `LookupRayQueryFunction::Proceed`; later calls return the cached id.
+    ///
+    /// The generated function takes `(ray query pointer, pointer to the u32 tracker)`
+    /// and returns a `bool`. With `ray_query_initialization_tracking` enabled it only
+    /// proceeds when the tracker has the `INITIALIZED` bit set; otherwise it returns
+    /// the `false` the result variable was initialised with. After proceeding it ORs
+    /// progress bits into the tracker.
+    fn write_ray_query_proceed(&mut self) -> spirv::Word {
+        if let Some(&word) = self
+            .ray_query_functions
+            .get(&LookupRayQueryFunction::Proceed)
+        {
+            return word;
+        }
+
+        let ray_query_type_id = self.get_ray_query_pointer_id();
+
+        let u32_ty = self.get_u32_type_id();
+        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
+
+        let bool_type_id = self.get_bool_type_id();
+        let bool_ptr_ty = self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
+
+        let func_ty = self.get_function_type(LookupFunctionType {
+            parameter_type_ids: vec![ray_query_type_id, u32_ptr_ty],
+            return_type_id: bool_type_id,
+        });
+
+        let mut function = Function::default();
+        let func_id = self.id_gen.next();
+        function.signature = Some(Instruction::function(
+            bool_type_id,
+            func_id,
+            spirv::FunctionControl::empty(),
+            func_ty,
+        ));
+
+        // Parameter 0: the ray query itself.
+        let query_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(ray_query_type_id, query_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 0,
+        });
+
+        // Parameter 1: pointer to the per-query u32 state tracker.
+        let init_tracker_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(u32_ptr_ty, init_tracker_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 1,
+        });
+
+        let block_id = self.id_gen.next();
+        let mut block = Block::new(block_id);
+
+        // Function-local result slot, initialised to `false`, so the merge block can
+        // load a defined value even when the guarded block is skipped.
+        // TODO: perhaps this could be replaced with an OpPhi?
+        let proceeded_id = self.id_gen.next();
+        let const_false = self.get_constant_scalar(crate::Literal::Bool(false));
+        block.body.push(Instruction::variable(
+            bool_ptr_ty,
+            proceeded_id,
+            spirv::StorageClass::Function,
+            Some(const_false),
+        ));
+
+        // The tracker is loaded unconditionally: the guarded block ORs new bits into
+        // this loaded value even when tracking-based branching is disabled.
+        let initialized_tracker_id = self.id_gen.next();
+        block.body.push(Instruction::load(
+            u32_ty,
+            initialized_tracker_id,
+            init_tracker_id,
+            None,
+        ));
+
+        let merge_id = self.id_gen.next();
+        let mut merge_block = Block::new(merge_id);
+
+        let valid_block_id = self.id_gen.next();
+        let mut valid_block = Block::new(valid_block_id);
+
+        // With tracking on, branch on the INITIALIZED bit; otherwise always proceed.
+        let instruction = if self.ray_query_initialization_tracking {
+            let is_initialized = write_ray_flags_contains_flags(
+                self,
+                &mut block,
+                initialized_tracker_id,
+                super::RayQueryPoint::INITIALIZED.bits(),
+            );
+
+            block.body.push(Instruction::selection_merge(
+                merge_id,
+                spirv::SelectionControl::NONE,
+            ));
+
+            Instruction::branch_conditional(is_initialized, valid_block_id, merge_id)
+        } else {
+            Instruction::branch(valid_block_id)
+        };
+
+        function.consume(block, instruction);
+
+        let has_proceeded = self.id_gen.next();
+        valid_block.body.push(Instruction::ray_query_proceed(
+            bool_type_id,
+            has_proceeded,
+            query_id,
+        ));
+
+        valid_block
+            .body
+            .push(Instruction::store(proceeded_id, has_proceeded, None));
+
+        // Bits to add to the tracker: PROCEED alone, or PROCEED | FINISHED_TRAVERSAL.
+        let add_flag_finished = self.get_constant_scalar(crate::Literal::U32(
+            (super::RayQueryPoint::PROCEED | super::RayQueryPoint::FINISHED_TRAVERSAL).bits(),
+        ));
+        let add_flag_continuing =
+            self.get_constant_scalar(crate::Literal::U32(super::RayQueryPoint::PROCEED.bits()));
+
+        // NOTE(review): assuming `Instruction::select` takes (condition, accept, reject),
+        // a `true` proceed result selects the "continuing" bits and `false` the
+        // "finished" bits — confirm against the helper's definition.
+        let add_flags_id = self.id_gen.next();
+        valid_block.body.push(Instruction::select(
+            u32_ty,
+            add_flags_id,
+            has_proceeded,
+            add_flag_continuing,
+            add_flag_finished,
+        ));
+        let final_flags = self.id_gen.next();
+        valid_block.body.push(Instruction::binary(
+            spirv::Op::BitwiseOr,
+            u32_ty,
+            final_flags,
+            initialized_tracker_id,
+            add_flags_id,
+        ));
+        valid_block
+            .body
+            .push(Instruction::store(init_tracker_id, final_flags, None));
+
+        function.consume(valid_block, Instruction::branch(merge_id));
+
+        let loaded_proceeded_id = self.id_gen.next();
+        merge_block.body.push(Instruction::load(
+            bool_type_id,
+            loaded_proceeded_id,
+            proceeded_id,
+            None,
+        ));
+
+        function.consume(merge_block, Instruction::return_value(loaded_proceeded_id));
+
+        function.to_words(&mut self.logical_layout.function_definitions);
+
+        self.ray_query_functions
+            .insert(LookupRayQueryFunction::Proceed, func_id);
+        func_id
+    }
+
+    /// Returns the id of the SPIR-V helper function that guards "generate intersection".
+    ///
+    /// The helper is written once and cached in `ray_query_functions` under
+    /// `LookupRayQueryFunction::GenerateIntersection`. The generated function takes
+    /// `(ray query pointer, pointer to the u32 tracker, f32 hit distance)` and
+    /// returns nothing.
+    fn write_ray_query_generate_intersection(&mut self) -> spirv::Word {
+        if let Some(&word) = self
+            .ray_query_functions
+            .get(&LookupRayQueryFunction::GenerateIntersection)
+        {
+            return word;
+        }
+
+        let ray_query_type_id = self.get_ray_query_pointer_id();
+
+        let u32_ty = self.get_u32_type_id();
+        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
+
+        let f32_type_id = self.get_f32_type_id();
+
+        let bool_type_id = self.get_bool_type_id();
+
+        let func_ty = self.get_function_type(LookupFunctionType {
+            parameter_type_ids: vec![ray_query_type_id,
u32_ptr_ty, f32_type_id],
+            return_type_id: self.void_type,
+        });
+
+        let mut function = Function::default();
+        let func_id = self.id_gen.next();
+        function.signature = Some(Instruction::function(
+            self.void_type,
+            func_id,
+            spirv::FunctionControl::empty(),
+            func_ty,
+        ));
+
+        // Parameter 0: the ray query.
+        let query_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(ray_query_type_id, query_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 0,
+        });
+
+        // Parameter 1: pointer to the per-query u32 state tracker.
+        let init_tracker_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(u32_ptr_ty, init_tracker_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 1,
+        });
+
+        // Parameter 2: the f32 value forwarded to the generate-intersection instruction.
+        let depth_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(f32_type_id, depth_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 2,
+        });
+
+        let block_id = self.id_gen.next();
+        let mut block = Block::new(block_id);
+
+        let valid_id = self.id_gen.next();
+        let mut valid_block = Block::new(valid_id);
+
+        let final_label_id = self.id_gen.next();
+        let final_block = Block::new(final_label_id);
+
+        // With tracking on, only continue when the tracker has PROCEED set and
+        // FINISHED_TRAVERSAL clear; otherwise jump straight to the final block.
+        let instruction = if self.ray_query_initialization_tracking {
+            let initialized_tracker_id = self.id_gen.next();
+            block.body.push(Instruction::load(
+                u32_ty,
+                initialized_tracker_id,
+                init_tracker_id,
+                None,
+            ));
+
+            let proceeded_id = write_ray_flags_contains_flags(
+                self,
+                &mut block,
+                initialized_tracker_id,
+                super::RayQueryPoint::PROCEED.bits(),
+            );
+            let finished_proceed_id = write_ray_flags_contains_flags(
+                self,
+                &mut block,
+                initialized_tracker_id,
+                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
+            );
+            // TODO: Is double calling this invalid? Can't find anything to suggest so.
+            let not_finished_id = self.id_gen.next();
+            block.body.push(Instruction::unary(
+                spirv::Op::LogicalNot,
+                bool_type_id,
+                not_finished_id,
+                finished_proceed_id,
+            ));
+
+            let is_valid_id = self.id_gen.next();
+            block.body.push(Instruction::binary(
+                spirv::Op::LogicalAnd,
+                bool_type_id,
+                is_valid_id,
+                not_finished_id,
+                proceeded_id,
+            ));
+
+            block.body.push(Instruction::selection_merge(
+                final_label_id,
+                spirv::SelectionControl::NONE,
+            ));
+
+            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
+        } else {
+            Instruction::branch(valid_id)
+        };
+
+        function.consume(block, instruction);
+
+        // Query the type of the current candidate intersection.
+        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
+            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
+        ));
+        let raw_kind_id = self.id_gen.next();
+        valid_block
+            .body
+            .push(Instruction::ray_query_get_intersection(
+                spirv::Op::RayQueryGetIntersectionTypeKHR,
+                u32_ty,
+                raw_kind_id,
+                query_id,
+                intersection_id,
+            ));
+
+        // Only generate an intersection when the candidate is an AABB.
+        let candidate_aabb_id = self.get_constant_scalar(crate::Literal::U32(
+            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR as _,
+        ));
+        let intersection_aabb_id = self.id_gen.next();
+        valid_block.body.push(Instruction::binary(
+            spirv::Op::IEqual,
+            bool_type_id,
+            intersection_aabb_id,
+            raw_kind_id,
+            candidate_aabb_id,
+        ));
+
+        let generate_label_id = self.id_gen.next();
+        let mut generate_block = Block::new(generate_label_id);
+
+        let merge_label_id = self.id_gen.next();
+        let merge_block = Block::new(merge_label_id);
+
+        valid_block.body.push(Instruction::selection_merge(
+            merge_label_id,
+            spirv::SelectionControl::NONE,
+        ));
+        function.consume(
+            valid_block,
+            Instruction::branch_conditional(
+                intersection_aabb_id,
+                generate_label_id,
+                merge_label_id,
+            ),
+        );
+
+        generate_block
+            .body
+            .push(Instruction::ray_query_generate_intersection(
+                query_id, depth_id,
+            ));
+
+        function.consume(generate_block, Instruction::branch(merge_label_id));
+        // The inner merge block just falls through to the final block.
+        function.consume(merge_block, Instruction::branch(final_label_id));
+
+        function.consume(final_block, Instruction::return_void());
+
+        function.to_words(&mut self.logical_layout.function_definitions);
+
+        self.ray_query_functions
+            .insert(LookupRayQueryFunction::GenerateIntersection, func_id);
+        func_id
+    }
+
+    /// Returns the id of the SPIR-V helper function that guards "confirm intersection".
+    ///
+    /// The helper is written once and cached in `ray_query_functions` under
+    /// `LookupRayQueryFunction::ConfirmIntersection`. The generated function takes
+    /// `(ray query pointer, pointer to the u32 tracker)` and returns nothing.
+    ///
+    /// With `ray_query_initialization_tracking` enabled it does nothing unless the
+    /// tracker has `PROCEED` set and `FINISHED_TRAVERSAL` clear. It then confirms the
+    /// intersection only when the candidate intersection type is a triangle.
+    fn write_ray_query_confirm_intersection(&mut self) -> spirv::Word {
+        if let Some(&word) = self
+            .ray_query_functions
+            .get(&LookupRayQueryFunction::ConfirmIntersection)
+        {
+            return word;
+        }
+
+        let ray_query_type_id = self.get_ray_query_pointer_id();
+
+        let u32_ty = self.get_u32_type_id();
+        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
+
+        let bool_type_id = self.get_bool_type_id();
+
+        let func_ty = self.get_function_type(LookupFunctionType {
+            parameter_type_ids: vec![ray_query_type_id, u32_ptr_ty],
+            return_type_id: self.void_type,
+        });
+
+        let mut function = Function::default();
+        let func_id = self.id_gen.next();
+        function.signature = Some(Instruction::function(
+            self.void_type,
+            func_id,
+            spirv::FunctionControl::empty(),
+            func_ty,
+        ));
+
+        // Parameter 0: the ray query.
+        let query_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(ray_query_type_id, query_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 0,
+        });
+
+        // Parameter 1: pointer to the per-query u32 state tracker.
+        let init_tracker_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(u32_ptr_ty, init_tracker_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 1,
+        });
+
+        let block_id = self.id_gen.next();
+        let mut block = Block::new(block_id);
+
+        let valid_id = self.id_gen.next();
+        let mut valid_block = Block::new(valid_id);
+
+        let final_label_id = self.id_gen.next();
+        let final_block = Block::new(final_label_id);
+
+        // Same state guard as the generate-intersection helper.
+        let instruction = if self.ray_query_initialization_tracking {
+            let initialized_tracker_id = self.id_gen.next();
+            block.body.push(Instruction::load(
+                u32_ty,
+                initialized_tracker_id,
+                init_tracker_id,
+                None,
+            ));
+
+            let proceeded_id = write_ray_flags_contains_flags(
+                self,
+                &mut block,
+                initialized_tracker_id,
+                super::RayQueryPoint::PROCEED.bits(),
+            );
+            let finished_proceed_id = write_ray_flags_contains_flags(
+                self,
+                &mut block,
+                initialized_tracker_id,
+                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
+            );
+            // TODO: Is double calling this invalid? Can't find anything to suggest so, but it seems strange not to
+            let not_finished_id = self.id_gen.next();
+            block.body.push(Instruction::unary(
+                spirv::Op::LogicalNot,
+                bool_type_id,
+                not_finished_id,
+                finished_proceed_id,
+            ));
+
+            let is_valid_id = self.id_gen.next();
+            block.body.push(Instruction::binary(
+                spirv::Op::LogicalAnd,
+                bool_type_id,
+                is_valid_id,
+                not_finished_id,
+                proceeded_id,
+            ));
+
+            block.body.push(Instruction::selection_merge(
+                final_label_id,
+                spirv::SelectionControl::NONE,
+            ));
+
+            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
+        } else {
+            Instruction::branch(valid_id)
+        };
+
+        function.consume(block, instruction);
+
+        // Query the type of the current candidate intersection.
+        let intersection_id = self.get_constant_scalar(crate::Literal::U32(
+            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as _,
+        ));
+        let raw_kind_id = self.id_gen.next();
+        valid_block
+            .body
+            .push(Instruction::ray_query_get_intersection(
+                spirv::Op::RayQueryGetIntersectionTypeKHR,
+                u32_ty,
+                raw_kind_id,
+                query_id,
+                intersection_id,
+            ));
+
+        // Only confirm when the candidate is a triangle.
+        let candidate_tri_id = self.get_constant_scalar(crate::Literal::U32(
+            spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR as _,
+        ));
+        let intersection_tri_id = self.id_gen.next();
+        valid_block.body.push(Instruction::binary(
+            spirv::Op::IEqual,
+            bool_type_id,
+            intersection_tri_id,
+            raw_kind_id,
+            candidate_tri_id,
+        ));
+
+        // NOTE: the `generate_*` names below hold the block that performs the confirm.
+        let generate_label_id = self.id_gen.next();
+        let mut generate_block = Block::new(generate_label_id);
+
+        let merge_label_id = self.id_gen.next();
+        let merge_block = Block::new(merge_label_id);
+
valid_block.body.push(Instruction::selection_merge(
+            merge_label_id,
+            spirv::SelectionControl::NONE,
+        ));
+        function.consume(
+            valid_block,
+            Instruction::branch_conditional(intersection_tri_id, generate_label_id, merge_label_id),
+        );
+
+        generate_block
+            .body
+            .push(Instruction::ray_query_confirm_intersection(query_id));
+
+        function.consume(generate_block, Instruction::branch(merge_label_id));
+        function.consume(merge_block, Instruction::branch(final_label_id));
+
+        function.consume(final_block, Instruction::return_void());
+
+        self.ray_query_functions
+            .insert(LookupRayQueryFunction::ConfirmIntersection, func_id);
+
+        function.to_words(&mut self.logical_layout.function_definitions);
+
+        func_id
+    }
+
+    /// Returns the id of the SPIR-V helper function that fetches the vertex
+    /// positions of the hit triangle for a ray query.
+    ///
+    /// One helper is written per value of `is_committed` and cached in
+    /// `ray_query_functions` under `LookupRayQueryFunction::GetVertexPositions`.
+    ///
+    /// The generated function takes `(ray query pointer, pointer to the u32 tracker)`
+    /// and returns a value of the module's `ray_vertex_return` special type. The
+    /// result starts as the null constant of that type and is only overwritten when:
+    ///
+    /// * tracking is disabled, or the tracker has `PROCEED` set and
+    ///   `FINISHED_TRAVERSAL` set (committed) / clear (candidate), and
+    /// * the selected intersection is a triangle.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `ir_module.special_types.ray_vertex_return` has not been generated.
+    fn write_ray_query_get_vertex_positions(
+        &mut self,
+        is_committed: bool,
+        ir_module: &crate::Module,
+    ) -> spirv::Word {
+        if let Some(&word) =
+            self.ray_query_functions
+                .get(&LookupRayQueryFunction::GetVertexPositions {
+                    committed: is_committed,
+                })
+        {
+            return word;
+        }
+
+        // Which intersection to inspect, and the "triangle" enumerant of the matching
+        // intersection-type enum.
+        let (committed_ty, committed_tri_ty) = if is_committed {
+            (
+                spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as u32,
+                spirv::RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR
+                    as u32,
+            )
+        } else {
+            (
+                spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR as u32,
+                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
+                    as u32,
+            )
+        };
+
+        let ray_query_type_id = self.get_ray_query_pointer_id();
+
+        let u32_ty = self.get_u32_type_id();
+        let u32_ptr_ty = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Function);
+
+        let rq_get_vertex_positions_ty_id = self.get_handle_type_id(
+            *ir_module
+                .special_types
+                .ray_vertex_return
+                .as_ref()
+                .expect("must be generated when reading in get vertex position"),
+        );
+        let ptr_return_ty =
+            self.get_pointer_type_id(rq_get_vertex_positions_ty_id, spirv::StorageClass::Function);
+
+        let bool_type_id = self.get_bool_type_id();
+
+        // The helper ends in a value return of the vertex-positions type, so both the
+        // function type and the function signature must declare that type, not `void`.
+        let func_ty = self.get_function_type(LookupFunctionType {
+            parameter_type_ids: vec![ray_query_type_id, u32_ptr_ty],
+            return_type_id: rq_get_vertex_positions_ty_id,
+        });
+
+        let mut function = Function::default();
+        let func_id = self.id_gen.next();
+        function.signature = Some(Instruction::function(
+            rq_get_vertex_positions_ty_id,
+            func_id,
+            spirv::FunctionControl::empty(),
+            func_ty,
+        ));
+
+        // Parameter 0: the ray query.
+        let query_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(ray_query_type_id, query_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 0,
+        });
+
+        // Parameter 1: pointer to the per-query u32 state tracker.
+        let init_tracker_id = self.id_gen.next();
+        let instruction = Instruction::function_parameter(u32_ptr_ty, init_tracker_id);
+        function.parameters.push(FunctionArgument {
+            instruction,
+            handle_id: 1,
+        });
+
+        let block_id = self.id_gen.next();
+        let mut block = Block::new(block_id);
+
+        // Function-local result slot, null-initialised so every path that skips the
+        // fetch still returns a defined value.
+        let return_id = self.id_gen.next();
+        block.body.push(Instruction::variable(
+            ptr_return_ty,
+            return_id,
+            spirv::StorageClass::Function,
+            Some(self.get_constant_null(rq_get_vertex_positions_ty_id)),
+        ));
+
+        let valid_id = self.id_gen.next();
+        let mut valid_block = Block::new(valid_id);
+
+        let final_label_id = self.id_gen.next();
+        let mut final_block = Block::new(final_label_id);
+
+        let instruction = if self.ray_query_initialization_tracking {
+            let initialized_tracker_id = self.id_gen.next();
+            block.body.push(Instruction::load(
+                u32_ty,
+                initialized_tracker_id,
+                init_tracker_id,
+                None,
+            ));
+
+            let proceeded_id = write_ray_flags_contains_flags(
+                self,
+                &mut block,
+                initialized_tracker_id,
+                super::RayQueryPoint::PROCEED.bits(),
+            );
+            let finished_proceed_id = write_ray_flags_contains_flags(
+                self,
+                &mut block,
+                initialized_tracker_id,
+                super::RayQueryPoint::FINISHED_TRAVERSAL.bits(),
+            );
+
+            // Committed queries require FINISHED_TRAVERSAL to be set; candidate
+            // queries require it to be clear.
+            let correct_finish_id = if is_committed {
+                finished_proceed_id
+            } else {
+                let not_finished_id = self.id_gen.next();
+                block.body.push(Instruction::unary(
+                    spirv::Op::LogicalNot,
+                    bool_type_id,
+                    not_finished_id,
+                    finished_proceed_id,
+                ));
+                not_finished_id
+            };
+
+            let is_valid_id = self.id_gen.next();
+            block.body.push(Instruction::binary(
+                spirv::Op::LogicalAnd,
+                bool_type_id,
+                is_valid_id,
+                correct_finish_id,
+                proceeded_id,
+            ));
+            block.body.push(Instruction::selection_merge(
+                final_label_id,
+                spirv::SelectionControl::NONE,
+            ));
+            Instruction::branch_conditional(is_valid_id, valid_id, final_label_id)
+        } else {
+            Instruction::branch(valid_id)
+        };
+
+        function.consume(block, instruction);
+
+        // Query the type of the selected (committed or candidate) intersection.
+        let intersection_id = self.get_constant_scalar(crate::Literal::U32(committed_ty));
+        let raw_kind_id = self.id_gen.next();
+        valid_block
+            .body
+            .push(Instruction::ray_query_get_intersection(
+                spirv::Op::RayQueryGetIntersectionTypeKHR,
+                u32_ty,
+                raw_kind_id,
+                query_id,
+                intersection_id,
+            ));
+
+        // Only fetch vertex positions when that intersection is a triangle.
+        let expected_tri_id = self.get_constant_scalar(crate::Literal::U32(committed_tri_ty));
+        let intersection_tri_id = self.id_gen.next();
+        valid_block.body.push(Instruction::binary(
+            spirv::Op::IEqual,
+            bool_type_id,
+            intersection_tri_id,
+            raw_kind_id,
+            expected_tri_id,
+        ));
+
+        let vertex_return_label_id = self.id_gen.next();
+        let mut vertex_return_block = Block::new(vertex_return_label_id);
+
+        let merge_label_id = self.id_gen.next();
+        let merge_block = Block::new(merge_label_id);
+
+        valid_block.body.push(Instruction::selection_merge(
+            merge_label_id,
+            spirv::SelectionControl::NONE,
+        ));
+        function.consume(
+            valid_block,
+            Instruction::branch_conditional(
+                intersection_tri_id,
+                vertex_return_label_id,
+                merge_label_id,
+            ),
+        );
+
+        let vertices_id = self.id_gen.next();
+        vertex_return_block
+            .body
+            .push(Instruction::ray_query_return_vertex_position(
+                rq_get_vertex_positions_ty_id,
+                vertices_id,
+                query_id,
+                intersection_id,
+            ));
+        vertex_return_block
+            .body
+            .push(Instruction::store(return_id, vertices_id, None));
+
+        function.consume(vertex_return_block, Instruction::branch(merge_label_id));
+        function.consume(merge_block, Instruction::branch(final_label_id));
+
+        let loaded_pos_id = self.id_gen.next();
+        final_block.body.push(Instruction::load(
+            rq_get_vertex_positions_ty_id,
+            loaded_pos_id,
+            return_id,
+            None,
+        ));
+
+        function.consume(final_block, Instruction::return_value(loaded_pos_id));
+
+        // Serialise the helper into the module, as every other ray-query helper does;
+        // without this the cached id would refer to a function that is never emitted.
+        function.to_words(&mut self.logical_layout.function_definitions);
+
+        self.ray_query_functions.insert(
+            LookupRayQueryFunction::GetVertexPositions {
+                committed: is_committed,
+            },
+            func_id,
+        );
+        func_id
+    }
+}
+
+impl BlockContext<'_> {
+    /// Emits a call to the guarded helper function for `function` on `query`.
+    ///
+    /// Every helper receives the cached id of the query expression and the id
+    /// recorded for it in `ray_query_tracker_expr`. `Terminate` emits nothing.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `query` has no entry in `ray_query_tracker_expr`.
+    pub(super) fn write_ray_query_function(
+        &mut self,
+        query: Handle<crate::Expression>,
+        function: &crate::RayQueryFunction,
+        block: &mut Block,
+    ) {
+        let query_id = self.cached[query];
+        let init_tracker_id = *self
+            .ray_query_tracker_expr
+            .get(&query)
+            .expect("not a cached ray query");
+
+        match *function {
+            crate::RayQueryFunction::Initialize {
+                acceleration_structure,
+                descriptor,
+            } => {
+                let desc_id = self.cached[descriptor];
+                let acc_struct_id = self.get_handle_id(acceleration_structure);
+
+                let func = self.writer.write_ray_query_initialize(self.ir_module);
+
+                let func_id = self.gen_id();
+                block.body.push(Instruction::function_call(
+                    self.writer.void_type,
+                    func_id,
+                    func,
+                    &[query_id, acc_struct_id, desc_id, init_tracker_id],
+                ));
+            }
+            crate::RayQueryFunction::Proceed { result } => {
+                // The call result is the value of the `result` expression.
+                let id = self.gen_id();
+                self.cached[result] = id;
+
+                let bool_ty = self.writer.get_bool_type_id();
+
+                let func_id = self.writer.write_ray_query_proceed();
+                block.body.push(Instruction::function_call(
+                    bool_ty,
+                    id,
+                    func_id,
+                    &[query_id, init_tracker_id],
+                ));
+            }
+            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
+                let hit_id = self.cached[hit_t];
+
+                let func_id = self.writer.write_ray_query_generate_intersection();
+
+                let func_call_id = self.gen_id();
+                block.body.push(Instruction::function_call(
+                    self.writer.void_type,
+                    func_call_id,
+                    func_id,
+                    &[query_id, init_tracker_id, hit_id],
+                ));
+            }
+            crate::RayQueryFunction::ConfirmIntersection => {
+                let func_id = self.writer.write_ray_query_confirm_intersection();
+
+                let func_call_id = self.gen_id();
+                block.body.push(Instruction::function_call(
+                    self.writer.void_type,
+                    func_call_id,
+                    func_id,
+                    &[query_id, init_tracker_id],
+                ));
+            }
+            crate::RayQueryFunction::Terminate => {}
+        }
+    }
+
+    /// Emits a call to the guarded vertex-positions helper and returns the id of
+    /// the call result.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `query` has no entry in `ray_query_tracker_expr`, or if the
+    /// module's `ray_vertex_return` special type has not been populated.
+    pub(super) fn write_ray_query_return_vertex_position(
+        &mut self,
+        query: Handle<crate::Expression>,
+        block: &mut Block,
+        is_committed: bool,
+    ) -> spirv::Word {
+        let fn_id = self
+            .writer
+            .write_ray_query_get_vertex_positions(is_committed, self.ir_module);
+
+        let query_id = self.cached[query];
+        let init_tracker_id = *self
+            .ray_query_tracker_expr
+            .get(&query)
+            .expect("not a cached ray query");
+
+        // The call's result type must match the helper's return type, since the
+        // returned id is used as a value of the vertex-positions type.
+        let ray_vertex_return_ty = self
+            .ir_module
+            .special_types
+            .ray_vertex_return
+            .expect("type should have been populated");
+        let ray_vertex_return_ty_id = self.writer.get_handle_type_id(ray_vertex_return_ty);
+
+        let func_call_id = self.gen_id();
+        block.body.push(Instruction::function_call(
+            ray_vertex_return_ty_id,
+            func_call_id,
+            fn_id,
+            &[query_id, init_tracker_id],
+        ));
+        func_call_id
     }
 }
diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs
index c86a53c6ef8..b65db628548 100644
--- a/naga/src/back/spv/writer.rs
+++ b/naga/src/back/spv/writer.rs
@@ -35,6 +35,9 @@ impl Function {
         for local_var in self.variables.values() {
             local_var.instruction.to_words(sink);
         }
+        for local_var in self.ray_query_tracker_variables.values() {
+            local_var.instruction.to_words(sink);
+        }
         for local_var in self.force_loop_bounding_vars.iter() {
             local_var.instruction.to_words(sink);
         }
@@ -71,12 +74,14 @@ impl Writer {
             capabilities_available: options.capabilities.clone(),
             capabilities_used,
             extensions_used: crate::FastIndexSet::default(),
+            debug_strings: vec![],
             debugs: vec![],
             annotations: vec![],
             flags: options.flags,
             bounds_check_policies: options.bounds_check_policies,
             zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
             force_loop_bounding: options.force_loop_bounding,
+            ray_query_initialization_tracking: options.ray_query_initialization_tracking,
             use_storage_input_output_16: options.use_storage_input_output_16,
             void_type,
             lookup_type: crate::FastHashMap::default(),
@@ -91,11 +96,11 @@ impl Writer {
             saved_cached: CachedExpressions::default(),
gl450_ext_inst_id, temp_list: Vec::new(), - ray_get_committed_intersection_function: None, - ray_get_candidate_intersection_function: None, + ray_query_functions: crate::FastHashMap::default(), io_f16_polyfills: super::f16_polyfill::F16IoPolyfill::new( options.use_storage_input_output_16, ), + debug_printf: None, }) } @@ -147,6 +152,7 @@ impl Writer { bounds_check_policies: self.bounds_check_policies, zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory, force_loop_bounding: self.force_loop_bounding, + ray_query_initialization_tracking: self.ray_query_initialization_tracking, use_storage_input_output_16: self.use_storage_input_output_16, capabilities_available: take(&mut self.capabilities_available), fake_missing_bindings: self.fake_missing_bindings, @@ -162,6 +168,7 @@ impl Writer { extensions_used: take(&mut self.extensions_used).recycle(), physical_layout: self.physical_layout.clone().recycle(), logical_layout: take(&mut self.logical_layout).recycle(), + debug_strings: take(&mut self.debug_strings).recycle(), debugs: take(&mut self.debugs).recycle(), annotations: take(&mut self.annotations).recycle(), lookup_type: take(&mut self.lookup_type).recycle(), @@ -173,9 +180,9 @@ impl Writer { global_variables: take(&mut self.global_variables).recycle(), saved_cached: take(&mut self.saved_cached).recycle(), temp_list: take(&mut self.temp_list).recycle(), - ray_get_candidate_intersection_function: None, - ray_get_committed_intersection_function: None, + ray_query_functions: take(&mut self.ray_query_functions).recycle(), io_f16_polyfills: take(&mut self.io_f16_polyfills).recycle(), + debug_printf: None, }; *self = fresh; @@ -1022,6 +1029,7 @@ impl Writer { expression_constness: super::ExpressionConstnessTracker::from_arena( &ir_function.expressions, ), + ray_query_tracker_expr: crate::FastHashMap::default(), }; // fill up the pre-emitted and const expressions @@ -1063,6 +1071,34 @@ impl Writer { .function .variables .insert(handle, LocalVariable { id, 
instruction }); + + if let crate::TypeInner::RayQuery { .. } = ir_module.types[variable.ty].inner { + // Don't refactor this into a struct: Although spirv itself allows opaque types in structs, + // the vulkan environment for spirv does not. Putting ray queries into structs can cause + // confusing bugs. + let u32_type_id = context.writer.get_u32_type_id(); + let ptr_u32_type_id = context + .writer + .get_pointer_type_id(u32_type_id, spirv::StorageClass::Function); + let tracker_id = context.gen_id(); + let tracker_init_id = context + .writer + .get_constant_scalar(crate::Literal::U32(super::RayQueryPoint::empty().bits())); + let tracker_instruction = Instruction::variable( + ptr_u32_type_id, + tracker_id, + spirv::StorageClass::Function, + Some(tracker_init_id), + ); + + context.function.ray_query_tracker_variables.insert( + handle, + LocalVariable { + id: tracker_id, + instruction: tracker_instruction, + }, + ); + } } for (handle, expr) in ir_function.expressions.iter() { @@ -2634,6 +2670,10 @@ impl Writer { Instruction::memory_model(addressing_model, memory_model) .to_words(&mut self.logical_layout.memory_model); + for debug_string in self.debug_strings.iter() { + debug_string.to_words(&mut self.logical_layout.debugs); + } + if self.flags.contains(WriterFlags::DEBUG) { for debug in self.debugs.iter() { debug.to_words(&mut self.logical_layout.debugs); @@ -2693,6 +2733,41 @@ impl Writer { pub(super) fn needs_f16_polyfill(&self, ty_inner: &crate::TypeInner) -> bool { self.io_f16_polyfills.needs_polyfill(ty_inner) } + + #[allow(dead_code)] + pub(super) fn write_debug_printf( + &mut self, + block: &mut Block, + string: &str, + format_params: &[Word], + ) { + if self.debug_printf.is_none() { + self.use_extension("SPV_KHR_non_semantic_info"); + let import_id = self.id_gen.next(); + Instruction::ext_inst_import(import_id, "NonSemantic.DebugPrintf") + .to_words(&mut self.logical_layout.ext_inst_imports); + self.debug_printf = Some(import_id) + } + + let import_id = 
self.debug_printf.unwrap(); + + let string_id = self.id_gen.next(); + self.debug_strings + .push(Instruction::string(string, string_id)); + + let mut operands = Vec::with_capacity(1 + format_params.len()); + operands.push(string_id); + operands.extend(format_params.iter()); + + let print_id = self.id_gen.next(); + block.body.push(Instruction::ext_inst( + import_id, + 1, + self.void_type, + print_id, + &operands, + )); + } } #[test] diff --git a/naga/tests/in/wgsl/ray-query-no-init-tracking.toml b/naga/tests/in/wgsl/ray-query-no-init-tracking.toml new file mode 100644 index 00000000000..e2602b7b4d2 --- /dev/null +++ b/naga/tests/in/wgsl/ray-query-no-init-tracking.toml @@ -0,0 +1,19 @@ +god_mode = true +targets = "SPIRV | METAL | HLSL" + +[msl] +fake_missing_bindings = true +lang_version = [2, 4] +spirv_cross_compatibility = false +zero_initialize_workgroup_memory = false + +[hlsl] +shader_model = "V6_5" +fake_missing_bindings = true +zero_initialize_workgroup_memory = true +# Not yet implemented +# ray_query_initialization_tracking = false + +[spv] +version = [1, 4] +ray_query_initialization_tracking = false diff --git a/naga/tests/in/wgsl/ray-query-no-init-tracking.wgsl b/naga/tests/in/wgsl/ray-query-no-init-tracking.wgsl new file mode 100644 index 00000000000..e8fabb0208c --- /dev/null +++ b/naga/tests/in/wgsl/ray-query-no-init-tracking.wgsl @@ -0,0 +1,97 @@ +/* +let RAY_FLAG_NONE = 0x00u; +let RAY_FLAG_FORCE_OPAQUE = 0x01u; +let RAY_FLAG_FORCE_NO_OPAQUE = 0x02u; +let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u; +let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u; +let RAY_FLAG_CULL_BACK_FACING = 0x10u; +let RAY_FLAG_CULL_FRONT_FACING = 0x20u; +let RAY_FLAG_CULL_OPAQUE = 0x40u; +let RAY_FLAG_CULL_NO_OPAQUE = 0x80u; +let RAY_FLAG_SKIP_TRIANGLES = 0x100u; +let RAY_FLAG_SKIP_AABBS = 0x200u; + +let RAY_QUERY_INTERSECTION_NONE = 0u; +let RAY_QUERY_INTERSECTION_TRIANGLE = 1u; +let RAY_QUERY_INTERSECTION_GENERATED = 2u; +let RAY_QUERY_INTERSECTION_AABB = 3u; + +struct RayDesc { + 
flags: u32, + cull_mask: u32, + t_min: f32, + t_max: f32, + origin: vec3, + dir: vec3, +} + +struct RayIntersection { + kind: u32, + t: f32, + instance_custom_data: u32, + instance_index: u32, + sbt_record_offset: u32, + geometry_index: u32, + primitive_index: u32, + barycentrics: vec2, + front_face: bool, + object_to_world: mat4x3, + world_to_object: mat4x3, +} +*/ + +fn query_loop(pos: vec3, dir: vec3, acs: acceleration_structure) -> RayIntersection { + var rq: ray_query; + rayQueryInitialize(&rq, acs, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, pos, dir)); + + while (rayQueryProceed(&rq)) {} + + return rayQueryGetCommittedIntersection(&rq); +} + +@group(0) @binding(0) +var acc_struct: acceleration_structure; + +struct Output { + visible: u32, + normal: vec3, +} + +@group(0) @binding(1) +var output: Output; + +fn get_torus_normal(world_point: vec3, intersection: RayIntersection) -> vec3 { + let local_point = intersection.world_to_object * vec4(world_point, 1.0); + let point_on_guiding_line = normalize(local_point.xy) * 2.4; + let world_point_on_guiding_line = intersection.object_to_world * vec4(point_on_guiding_line, 0.0, 1.0); + return normalize(world_point - world_point_on_guiding_line); +} + + + +@compute @workgroup_size(1) +fn main() { + let pos = vec3(0.0); + let dir = vec3(0.0, 1.0, 0.0); + let intersection = query_loop(pos, dir, acc_struct); + + output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE); + output.normal = get_torus_normal(dir * intersection.t, intersection); +} + +@compute @workgroup_size(1) +fn main_candidate() { + let pos = vec3(0.0); + let dir = vec3(0.0, 1.0, 0.0); + + var rq: ray_query; + rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, pos, dir)); + let intersection = rayQueryGetCandidateIntersection(&rq); + if (intersection.kind == RAY_QUERY_INTERSECTION_AABB) { + rayQueryGenerateIntersection(&rq, 10.0); + } else if (intersection.kind == 
RAY_QUERY_INTERSECTION_TRIANGLE) { + rayQueryConfirmIntersection(&rq); + } else { + rayQueryTerminate(&rq); + } +} diff --git a/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.hlsl b/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.hlsl new file mode 100644 index 00000000000..68be09b6b01 --- /dev/null +++ b/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.hlsl @@ -0,0 +1,165 @@ +struct RayIntersection { + uint kind; + float t; + uint instance_custom_data; + uint instance_index; + uint sbt_record_offset; + uint geometry_index; + uint primitive_index; + float2 barycentrics; + bool front_face; + int _pad9_0; + int _pad9_1; + row_major float4x3 object_to_world; + int _pad10_0; + row_major float4x3 world_to_object; + int _end_pad_0; +}; + +struct RayDesc_ { + uint flags; + uint cull_mask; + float tmin; + float tmax; + float3 origin; + int _pad5_0; + float3 dir; + int _end_pad_0; +}; + +struct Output { + uint visible; + int _pad1_0; + int _pad1_1; + int _pad1_2; + float3 normal; + int _end_pad_0; +}; + +RayDesc RayDescFromRayDesc_(RayDesc_ arg0) { + RayDesc ret = (RayDesc)0; + ret.Origin = arg0.origin; + ret.TMin = arg0.tmin; + ret.Direction = arg0.dir; + ret.TMax = arg0.tmax; + return ret; +} + +RaytracingAccelerationStructure acc_struct : register(t0); +RWByteAddressBuffer output : register(u1); + +RayDesc_ ConstructRayDesc_(uint arg0, uint arg1, float arg2, float arg3, float3 arg4, float3 arg5) { + RayDesc_ ret = (RayDesc_)0; + ret.flags = arg0; + ret.cull_mask = arg1; + ret.tmin = arg2; + ret.tmax = arg3; + ret.origin = arg4; + ret.dir = arg5; + return ret; +} + +RayIntersection GetCommittedIntersection(RayQuery rq) { + RayIntersection ret = (RayIntersection)0; + ret.kind = rq.CommittedStatus(); + if( rq.CommittedStatus() == COMMITTED_NOTHING) {} else { + ret.t = rq.CommittedRayT(); + ret.instance_custom_data = rq.CommittedInstanceID(); + ret.instance_index = rq.CommittedInstanceIndex(); + ret.sbt_record_offset = 
rq.CommittedInstanceContributionToHitGroupIndex(); + ret.geometry_index = rq.CommittedGeometryIndex(); + ret.primitive_index = rq.CommittedPrimitiveIndex(); + if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) { + ret.barycentrics = rq.CommittedTriangleBarycentrics(); + ret.front_face = rq.CommittedTriangleFrontFace(); + } + ret.object_to_world = rq.CommittedObjectToWorld4x3(); + ret.world_to_object = rq.CommittedWorldToObject4x3(); + } + return ret; +} + +RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructure acs) +{ + RayQuery rq_1; + + rq_1.TraceRayInline(acs, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir))); + uint2 loop_bound = uint2(4294967295u, 4294967295u); + while(true) { + if (all(loop_bound == uint2(0u, 0u))) { break; } + loop_bound -= uint2(loop_bound.y == 0u, 1u); + const bool _e9 = rq_1.Proceed(); + if (_e9) { + } else { + break; + } + { + } + } + const RayIntersection rayintersection = GetCommittedIntersection(rq_1); + return rayintersection; +} + +float3 get_torus_normal(float3 world_point, RayIntersection intersection) +{ + float3 local_point = mul(float4(world_point, 1.0), intersection.world_to_object); + float2 point_on_guiding_line = (normalize(local_point.xy) * 2.4); + float3 world_point_on_guiding_line = mul(float4(point_on_guiding_line, 0.0, 1.0), intersection.object_to_world); + return normalize((world_point - world_point_on_guiding_line)); +} + +[numthreads(1, 1, 1)] +void main() +{ + float3 pos_1 = (0.0).xxx; + float3 dir_1 = float3(0.0, 1.0, 0.0); + const RayIntersection _e7 = query_loop(pos_1, dir_1, acc_struct); + output.Store(0, asuint(uint((_e7.kind == 0u)))); + const float3 _e18 = get_torus_normal((dir_1 * _e7.t), _e7); + output.Store3(16, asuint(_e18)); + return; +} + +RayIntersection GetCandidateIntersection(RayQuery rq) { + RayIntersection ret = 
(RayIntersection)0; + CANDIDATE_TYPE kind = rq.CandidateType(); + if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { + ret.kind = 1; + ret.t = rq.CandidateTriangleRayT(); + ret.barycentrics = rq.CandidateTriangleBarycentrics(); + ret.front_face = rq.CandidateTriangleFrontFace(); + } else { + ret.kind = 3; + } + ret.instance_custom_data = rq.CandidateInstanceID(); + ret.instance_index = rq.CandidateInstanceIndex(); + ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex(); + ret.geometry_index = rq.CandidateGeometryIndex(); + ret.primitive_index = rq.CandidatePrimitiveIndex(); + ret.object_to_world = rq.CandidateObjectToWorld4x3(); + ret.world_to_object = rq.CandidateWorldToObject4x3(); + return ret; +} + +[numthreads(1, 1, 1)] +void main_candidate() +{ + RayQuery rq; + + float3 pos_2 = (0.0).xxx; + float3 dir_2 = float3(0.0, 1.0, 0.0); + rq.TraceRayInline(acc_struct, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2))); + RayIntersection intersection_1 = GetCandidateIntersection(rq); + if ((intersection_1.kind == 3u)) { + rq.CommitProceduralPrimitiveHit(10.0); + return; + } else { + if ((intersection_1.kind == 1u)) { + rq.CommitNonOpaqueTriangleHit(); + return; + } else { + rq.Abort(); + return; + } + } +} diff --git a/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.ron b/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.ron new file mode 100644 index 00000000000..a31e1db125a --- /dev/null +++ b/naga/tests/out/hlsl/wgsl-ray-query-no-init-tracking.ron @@ -0,0 +1,16 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_5", + ), + ( + entry_point:"main_candidate", + target_profile:"cs_6_5", + ), + ], +) diff --git a/naga/tests/out/msl/wgsl-ray-query-no-init-tracking.msl b/naga/tests/out/msl/wgsl-ray-query-no-init-tracking.msl new file mode 100644 index 
00000000000..55840c10920 --- /dev/null +++ b/naga/tests/out/msl/wgsl-ray-query-no-init-tracking.msl @@ -0,0 +1,116 @@ +// language: metal2.4 +#include +#include + +using metal::uint; +struct _RayQuery { + metal::raytracing::intersector intersector; + metal::raytracing::intersector::result_type intersection; + bool ready = false; +}; +constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) { + return ty==metal::raytracing::intersection_type::triangle ? 1 : + ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0; +} + +struct RayIntersection { + uint kind; + float t; + uint instance_custom_data; + uint instance_index; + uint sbt_record_offset; + uint geometry_index; + uint primitive_index; + metal::float2 barycentrics; + bool front_face; + char _pad9[11]; + metal::float4x3 object_to_world; + metal::float4x3 world_to_object; +}; +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; +struct Output { + uint visible; + char _pad1[12]; + metal::float3 normal; +}; + +RayIntersection query_loop( + metal::float3 pos, + metal::float3 dir, + metal::raytracing::instance_acceleration_structure acs +) { + _RayQuery rq_1 = {}; + RayDesc _e8 = RayDesc {4u, 255u, 0.1, 100.0, pos, dir}; + rq_1.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq_1.intersector.set_opacity_cull_mode((_e8.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e8.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq_1.intersector.force_opacity((_e8.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e8.flags & 2) != 0 ? 
metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq_1.intersector.accept_any_intersection((_e8.flags & 4) != 0); + rq_1.intersection = rq_1.intersector.intersect(metal::raytracing::ray(_e8.origin, _e8.dir, _e8.tmin, _e8.tmax), acs, _e8.cull_mask); rq_1.ready = true; + uint2 loop_bound = uint2(4294967295u); + while(true) { + if (metal::all(loop_bound == uint2(0u))) { break; } + loop_bound -= uint2(loop_bound.y == 0u, 1u); + bool _e9 = rq_1.ready; + if (_e9) { + } else { + break; + } + } + return RayIntersection {_map_intersection_type(rq_1.intersection.type), rq_1.intersection.distance, rq_1.intersection.user_instance_id, rq_1.intersection.instance_id, {}, rq_1.intersection.geometry_id, rq_1.intersection.primitive_id, rq_1.intersection.triangle_barycentric_coord, rq_1.intersection.triangle_front_facing, {}, rq_1.intersection.object_to_world_transform, rq_1.intersection.world_to_object_transform}; +} + +metal::float3 get_torus_normal( + metal::float3 world_point, + RayIntersection intersection +) { + metal::float3 local_point = intersection.world_to_object * metal::float4(world_point, 1.0); + metal::float2 point_on_guiding_line = metal::normalize(local_point.xy) * 2.4; + metal::float3 world_point_on_guiding_line = intersection.object_to_world * metal::float4(point_on_guiding_line, 0.0, 1.0); + return metal::normalize(world_point - world_point_on_guiding_line); +} + +kernel void main_( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +, device Output& output [[user(fake0)]] +) { + metal::float3 pos_1 = metal::float3(0.0); + metal::float3 dir_1 = metal::float3(0.0, 1.0, 0.0); + RayIntersection _e7 = query_loop(pos_1, dir_1, acc_struct); + output.visible = static_cast(_e7.kind == 0u); + metal::float3 _e18 = get_torus_normal(dir_1 * _e7.t, _e7); + output.normal = _e18; + return; +} + + +kernel void main_candidate( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +) { 
+ _RayQuery rq = {}; + metal::float3 pos_2 = metal::float3(0.0); + metal::float3 dir_2 = metal::float3(0.0, 1.0, 0.0); + RayDesc _e12 = RayDesc {4u, 255u, 0.1, 100.0, pos_2, dir_2}; + rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq.intersector.accept_any_intersection((_e12.flags & 4) != 0); + rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true; + RayIntersection intersection_1 = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform}; + if (intersection_1.kind == 3u) { + return; + } else { + if (intersection_1.kind == 1u) { + return; + } else { + rq.ready = false; + return; + } + } +} diff --git a/naga/tests/out/spv/wgsl-aliased-ray-query.spvasm b/naga/tests/out/spv/wgsl-aliased-ray-query.spvasm index b095e8b8e83..05428ce7be4 100644 --- a/naga/tests/out/spv/wgsl-aliased-ray-query.spvasm +++ b/naga/tests/out/spv/wgsl-aliased-ray-query.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 102 +; Bound: 226 OpCapability Shader OpCapability RayQueryKHR OpExtension "SPV_KHR_ray_query" @@ -59,109 +59,259 @@ OpDecorate %13 Binding 0 %29 = 
OpConstant %5 10 %30 = OpConstant %7 1 %32 = OpTypePointer Function %3 -%40 = OpTypePointer Function %12 -%41 = OpTypePointer Function %7 -%42 = OpTypePointer Function %11 -%43 = OpTypePointer Function %9 -%44 = OpTypePointer Function %10 -%45 = OpTypePointer Function %5 -%46 = OpTypeFunction %12 %32 -%48 = OpConstantNull %12 -%52 = OpConstant %7 0 -%67 = OpConstant %7 2 -%71 = OpConstant %7 5 -%73 = OpConstant %7 6 -%75 = OpConstant %7 9 -%77 = OpConstant %7 10 -%86 = OpConstant %7 7 -%88 = OpConstant %7 8 -%47 = OpFunction %12 None %46 -%49 = OpFunctionParameter %32 -%50 = OpLabel -%51 = OpVariable %40 Function %48 -%53 = OpRayQueryGetIntersectionTypeKHR %7 %49 %52 -%54 = OpIEqual %10 %53 %52 -%55 = OpSelect %7 %54 %30 %28 -%56 = OpAccessChain %41 %51 %52 -OpStore %56 %55 -%57 = OpINotEqual %10 %55 %52 -OpSelectionMerge %59 None -OpBranchConditional %57 %58 %59 -%58 = OpLabel -%60 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %7 %49 %52 -%61 = OpRayQueryGetIntersectionInstanceIdKHR %7 %49 %52 -%62 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %7 %49 %52 -%63 = OpRayQueryGetIntersectionGeometryIndexKHR %7 %49 %52 -%64 = OpRayQueryGetIntersectionPrimitiveIndexKHR %7 %49 %52 -%65 = OpRayQueryGetIntersectionObjectToWorldKHR %11 %49 %52 -%66 = OpRayQueryGetIntersectionWorldToObjectKHR %11 %49 %52 -%68 = OpAccessChain %41 %51 %67 -OpStore %68 %60 -%69 = OpAccessChain %41 %51 %28 -OpStore %69 %61 -%70 = OpAccessChain %41 %51 %23 -OpStore %70 %62 -%72 = OpAccessChain %41 %51 %71 -OpStore %72 %63 -%74 = OpAccessChain %41 %51 %73 -OpStore %74 %64 -%76 = OpAccessChain %42 %51 %75 -OpStore %76 %65 -%78 = OpAccessChain %42 %51 %77 -OpStore %78 %66 -%79 = OpIEqual %10 %55 %30 -OpSelectionMerge %81 None -OpBranchConditional %57 %80 %81 -%80 = OpLabel -%82 = OpRayQueryGetIntersectionTKHR %5 %49 %52 -%83 = OpAccessChain %45 %51 %30 -OpStore %83 %82 -%84 = OpRayQueryGetIntersectionBarycentricsKHR %9 %49 %52 -%85 = OpRayQueryGetIntersectionFrontFaceKHR 
%10 %49 %52 -%87 = OpAccessChain %43 %51 %86 -OpStore %87 %84 -%89 = OpAccessChain %44 %51 %88 -OpStore %89 %85 -OpBranch %81 -%81 = OpLabel -OpBranch %59 -%59 = OpLabel -%90 = OpLoad %12 %51 -OpReturnValue %90 +%33 = OpTypePointer Function %7 +%35 = OpConstant %7 0 +%37 = OpTypeVector %10 3 +%38 = OpTypeFunction %2 %32 %4 %8 %33 +%65 = OpConstant %7 256 +%68 = OpConstant %7 512 +%73 = OpConstant %7 16 +%76 = OpConstant %7 32 +%87 = OpConstant %7 2 +%90 = OpConstant %7 64 +%93 = OpConstant %7 128 +%118 = OpTypePointer Function %12 +%119 = OpTypePointer Function %11 +%120 = OpTypePointer Function %9 +%121 = OpTypePointer Function %10 +%122 = OpTypePointer Function %5 +%123 = OpTypeFunction %12 %32 %33 +%125 = OpConstantNull %12 +%156 = OpConstant %7 5 +%158 = OpConstant %7 6 +%160 = OpConstant %7 9 +%162 = OpConstant %7 10 +%171 = OpConstant %7 7 +%173 = OpConstant %7 8 +%182 = OpTypeFunction %2 %32 %33 %5 +%207 = OpTypeFunction %2 %32 %33 +%39 = OpFunction %2 None %38 +%40 = OpFunctionParameter %32 +%41 = OpFunctionParameter %4 +%42 = OpFunctionParameter %8 +%43 = OpFunctionParameter %33 +%44 = OpLabel +%45 = OpCompositeExtract %7 %42 0 +%46 = OpCompositeExtract %7 %42 1 +%47 = OpCompositeExtract %5 %42 2 +%48 = OpCompositeExtract %5 %42 3 +%49 = OpCompositeExtract %6 %42 4 +%50 = OpCompositeExtract %6 %42 5 +%51 = OpFOrdLessThanEqual %10 %47 %48 +%52 = OpFOrdGreaterThanEqual %10 %47 %19 +%53 = OpIsInf %37 %49 +%54 = OpAny %10 %53 +%55 = OpIsNan %37 %49 +%56 = OpAny %10 %55 +%57 = OpLogicalOr %10 %56 %54 +%58 = OpLogicalNot %10 %57 +%59 = OpIsInf %37 %50 +%60 = OpAny %10 %59 +%61 = OpIsNan %37 %50 +%62 = OpAny %10 %61 +%63 = OpLogicalOr %10 %62 %60 +%64 = OpLogicalNot %10 %63 +%66 = OpBitwiseAnd %7 %45 %65 +%67 = OpINotEqual %10 %66 %35 +%69 = OpBitwiseAnd %7 %45 %68 +%70 = OpINotEqual %10 %69 %35 +%71 = OpLogicalAnd %10 %70 %67 +%72 = OpLogicalNot %10 %71 +%74 = OpBitwiseAnd %7 %45 %73 +%75 = OpINotEqual %10 %74 %35 +%77 = OpBitwiseAnd %7 %45 %76 +%78 = 
OpINotEqual %10 %77 %35 +%79 = OpLogicalAnd %10 %78 %67 +%80 = OpLogicalAnd %10 %78 %75 +%81 = OpLogicalAnd %10 %75 %67 +%82 = OpLogicalOr %10 %81 %79 +%83 = OpLogicalOr %10 %82 %80 +%84 = OpLogicalNot %10 %83 +%85 = OpBitwiseAnd %7 %45 %30 +%86 = OpINotEqual %10 %85 %35 +%88 = OpBitwiseAnd %7 %45 %87 +%89 = OpINotEqual %10 %88 %35 +%91 = OpBitwiseAnd %7 %45 %90 +%92 = OpINotEqual %10 %91 %35 +%94 = OpBitwiseAnd %7 %45 %93 +%95 = OpINotEqual %10 %94 %35 +%96 = OpLogicalAnd %10 %95 %86 +%97 = OpLogicalAnd %10 %95 %89 +%98 = OpLogicalAnd %10 %95 %92 +%99 = OpLogicalAnd %10 %92 %86 +%100 = OpLogicalAnd %10 %92 %89 +%101 = OpLogicalAnd %10 %89 %86 +%102 = OpLogicalOr %10 %101 %96 +%103 = OpLogicalOr %10 %102 %97 +%104 = OpLogicalOr %10 %103 %98 +%105 = OpLogicalOr %10 %104 %99 +%106 = OpLogicalOr %10 %105 %100 +%107 = OpLogicalNot %10 %106 +%108 = OpLogicalAnd %10 %51 %52 +%109 = OpLogicalAnd %10 %58 %64 +%110 = OpLogicalAnd %10 %72 %84 +%111 = OpLogicalAnd %10 %110 %107 +%112 = OpLogicalAnd %10 %108 %109 +%113 = OpLogicalAnd %10 %112 %111 +OpSelectionMerge %114 None +OpBranchConditional %113 %116 %115 +%116 = OpLabel +OpRayQueryInitializeKHR %40 %41 %45 %46 %49 %47 %50 %48 +OpStore %43 %30 +OpBranch %114 +%115 = OpLabel +OpBranch %114 +%114 = OpLabel +OpReturn +OpFunctionEnd +%124 = OpFunction %12 None %123 +%126 = OpFunctionParameter %32 +%127 = OpFunctionParameter %33 +%128 = OpLabel +%129 = OpVariable %118 Function %125 +%130 = OpLoad %7 %127 +%131 = OpBitwiseAnd %7 %130 %87 +%132 = OpINotEqual %10 %131 %35 +%133 = OpBitwiseAnd %7 %130 %23 +%134 = OpINotEqual %10 %133 %35 +%135 = OpLogicalNot %10 %134 +%136 = OpLogicalAnd %10 %135 %132 +OpSelectionMerge %138 None +OpBranchConditional %136 %137 %138 +%137 = OpLabel +%139 = OpRayQueryGetIntersectionTypeKHR %7 %126 %35 +%140 = OpIEqual %10 %139 %35 +%141 = OpSelect %7 %140 %30 %28 +%142 = OpAccessChain %33 %129 %35 +OpStore %142 %141 +%143 = OpINotEqual %10 %141 %35 +OpSelectionMerge %145 None +OpBranchConditional 
%143 %144 %145 +%144 = OpLabel +%146 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %7 %126 %35 +%147 = OpRayQueryGetIntersectionInstanceIdKHR %7 %126 %35 +%148 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %7 %126 %35 +%149 = OpRayQueryGetIntersectionGeometryIndexKHR %7 %126 %35 +%150 = OpRayQueryGetIntersectionPrimitiveIndexKHR %7 %126 %35 +%151 = OpRayQueryGetIntersectionObjectToWorldKHR %11 %126 %35 +%152 = OpRayQueryGetIntersectionWorldToObjectKHR %11 %126 %35 +%153 = OpAccessChain %33 %129 %87 +OpStore %153 %146 +%154 = OpAccessChain %33 %129 %28 +OpStore %154 %147 +%155 = OpAccessChain %33 %129 %23 +OpStore %155 %148 +%157 = OpAccessChain %33 %129 %156 +OpStore %157 %149 +%159 = OpAccessChain %33 %129 %158 +OpStore %159 %150 +%161 = OpAccessChain %119 %129 %160 +OpStore %161 %151 +%163 = OpAccessChain %119 %129 %162 +OpStore %163 %152 +%164 = OpIEqual %10 %141 %30 +OpSelectionMerge %166 None +OpBranchConditional %143 %165 %166 +%165 = OpLabel +%167 = OpRayQueryGetIntersectionTKHR %5 %126 %35 +%168 = OpAccessChain %122 %129 %30 +OpStore %168 %167 +%169 = OpRayQueryGetIntersectionBarycentricsKHR %9 %126 %35 +%170 = OpRayQueryGetIntersectionFrontFaceKHR %10 %126 %35 +%172 = OpAccessChain %120 %129 %171 +OpStore %172 %169 +%174 = OpAccessChain %121 %129 %173 +OpStore %174 %170 +OpBranch %166 +%166 = OpLabel +OpBranch %145 +%145 = OpLabel +OpBranch %138 +%138 = OpLabel +%175 = OpLoad %12 %129 +OpReturnValue %175 +OpFunctionEnd +%183 = OpFunction %2 None %182 +%184 = OpFunctionParameter %32 +%185 = OpFunctionParameter %33 +%186 = OpFunctionParameter %5 +%187 = OpLabel +%190 = OpLoad %7 %185 +%191 = OpBitwiseAnd %7 %190 %87 +%192 = OpINotEqual %10 %191 %35 +%193 = OpBitwiseAnd %7 %190 %23 +%194 = OpINotEqual %10 %193 %35 +%195 = OpLogicalNot %10 %194 +%196 = OpLogicalAnd %10 %195 %192 +OpSelectionMerge %189 None +OpBranchConditional %196 %188 %189 +%188 = OpLabel +%197 = OpRayQueryGetIntersectionTypeKHR %7 %184 %35 +%198 = OpIEqual %10 
%197 %30 +OpSelectionMerge %200 None +OpBranchConditional %198 %199 %200 +%199 = OpLabel +OpRayQueryGenerateIntersectionKHR %184 %186 +OpBranch %200 +%200 = OpLabel +OpBranch %189 +%189 = OpLabel +OpReturn +OpFunctionEnd +%208 = OpFunction %2 None %207 +%209 = OpFunctionParameter %32 +%210 = OpFunctionParameter %33 +%211 = OpLabel +%214 = OpLoad %7 %210 +%215 = OpBitwiseAnd %7 %214 %87 +%216 = OpINotEqual %10 %215 %35 +%217 = OpBitwiseAnd %7 %214 %23 +%218 = OpINotEqual %10 %217 %35 +%219 = OpLogicalNot %10 %218 +%220 = OpLogicalAnd %10 %219 %216 +OpSelectionMerge %213 None +OpBranchConditional %220 %212 %213 +%212 = OpLabel +%221 = OpRayQueryGetIntersectionTypeKHR %7 %209 %35 +%222 = OpIEqual %10 %221 %35 +OpSelectionMerge %224 None +OpBranchConditional %222 %223 %224 +%223 = OpLabel +OpRayQueryConfirmIntersectionKHR %209 +OpBranch %224 +%224 = OpLabel +OpBranch %213 +%213 = OpLabel +OpReturn OpFunctionEnd %16 = OpFunction %2 None %17 %15 = OpLabel %31 = OpVariable %32 Function +%34 = OpVariable %33 Function %35 %18 = OpLoad %4 %13 -OpBranch %33 -%33 = OpLabel -%34 = OpCompositeExtract %7 %27 0 -%35 = OpCompositeExtract %7 %27 1 -%36 = OpCompositeExtract %5 %27 2 -%37 = OpCompositeExtract %5 %27 3 -%38 = OpCompositeExtract %6 %27 4 -%39 = OpCompositeExtract %6 %27 5 -OpRayQueryInitializeKHR %31 %18 %34 %35 %38 %36 %39 %37 -%91 = OpFunctionCall %12 %47 %31 -%92 = OpCompositeExtract %7 %91 0 -%93 = OpIEqual %10 %92 %28 -OpSelectionMerge %94 None -OpBranchConditional %93 %95 %96 -%95 = OpLabel -OpRayQueryGenerateIntersectionKHR %31 %29 +OpBranch %36 +%36 = OpLabel +%117 = OpFunctionCall %2 %39 %31 %18 %27 %34 +%176 = OpFunctionCall %12 %124 %31 %34 +%177 = OpCompositeExtract %7 %176 0 +%178 = OpIEqual %10 %177 %28 +OpSelectionMerge %179 None +OpBranchConditional %178 %180 %181 +%180 = OpLabel +%201 = OpFunctionCall %2 %183 %31 %34 %29 OpReturn -%96 = OpLabel -%97 = OpCompositeExtract %7 %91 0 -%98 = OpIEqual %10 %97 %30 -OpSelectionMerge %99 None -OpBranchConditional 
%98 %100 %101 -%100 = OpLabel -OpRayQueryConfirmIntersectionKHR %31 +%181 = OpLabel +%202 = OpCompositeExtract %7 %176 0 +%203 = OpIEqual %10 %202 %30 +OpSelectionMerge %204 None +OpBranchConditional %203 %205 %206 +%205 = OpLabel +%225 = OpFunctionCall %2 %208 %31 %34 OpReturn -%101 = OpLabel +%206 = OpLabel OpReturn -%99 = OpLabel -OpBranch %94 -%94 = OpLabel +%204 = OpLabel +OpBranch %179 +%179 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-overrides-ray-query.main.spvasm b/naga/tests/out/spv/wgsl-overrides-ray-query.main.spvasm index 34a8df87711..c2377a681bf 100644 --- a/naga/tests/out/spv/wgsl-overrides-ray-query.main.spvasm +++ b/naga/tests/out/spv/wgsl-overrides-ray-query.main.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 65 +; Bound: 161 OpCapability Shader OpCapability RayQueryKHR OpExtension "SPV_KHR_ray_query" @@ -40,61 +40,171 @@ OpDecorate %10 Binding 0 %25 = OpConstantComposite %7 %22 %23 %24 %26 = OpConstantComposite %8 %16 %17 %18 %19 %21 %25 %28 = OpTypePointer Function %5 -%40 = OpTypeVector %6 2 -%41 = OpTypePointer Function %40 -%42 = OpTypeBool -%43 = OpTypeVector %42 2 -%44 = OpConstant %6 0 -%45 = OpConstantComposite %40 %44 %44 -%46 = OpConstant %6 1 -%47 = OpConstant %6 4294967295 -%48 = OpConstantComposite %40 %47 %47 +%29 = OpTypePointer Function %6 +%31 = OpConstant %6 0 +%33 = OpTypeBool +%34 = OpTypeVector %33 3 +%35 = OpTypeFunction %2 %28 %4 %8 %29 +%50 = OpConstant %3 0 +%63 = OpConstant %6 256 +%66 = OpConstant %6 512 +%71 = OpConstant %6 16 +%74 = OpConstant %6 32 +%83 = OpConstant %6 1 +%86 = OpConstant %6 2 +%89 = OpConstant %6 64 +%92 = OpConstant %6 128 +%121 = OpTypeVector %6 2 +%122 = OpTypePointer Function %121 +%123 = OpTypeVector %33 2 +%124 = OpConstantComposite %121 %31 %31 +%125 = OpConstant %6 4294967295 +%126 = OpConstantComposite %121 %125 %125 +%139 = OpTypePointer Function %33 +%140 = OpTypeFunction %33 %28 %29 +%146 = 
OpConstantFalse %33 +%153 = OpConstant %6 6 +%36 = OpFunction %2 None %35 +%37 = OpFunctionParameter %28 +%38 = OpFunctionParameter %4 +%39 = OpFunctionParameter %8 +%40 = OpFunctionParameter %29 +%41 = OpLabel +%42 = OpCompositeExtract %6 %39 0 +%43 = OpCompositeExtract %6 %39 1 +%44 = OpCompositeExtract %3 %39 2 +%45 = OpCompositeExtract %3 %39 3 +%46 = OpCompositeExtract %7 %39 4 +%47 = OpCompositeExtract %7 %39 5 +%48 = OpFOrdLessThanEqual %33 %44 %45 +%49 = OpFOrdGreaterThanEqual %33 %44 %50 +%51 = OpIsInf %34 %46 +%52 = OpAny %33 %51 +%53 = OpIsNan %34 %46 +%54 = OpAny %33 %53 +%55 = OpLogicalOr %33 %54 %52 +%56 = OpLogicalNot %33 %55 +%57 = OpIsInf %34 %47 +%58 = OpAny %33 %57 +%59 = OpIsNan %34 %47 +%60 = OpAny %33 %59 +%61 = OpLogicalOr %33 %60 %58 +%62 = OpLogicalNot %33 %61 +%64 = OpBitwiseAnd %6 %42 %63 +%65 = OpINotEqual %33 %64 %31 +%67 = OpBitwiseAnd %6 %42 %66 +%68 = OpINotEqual %33 %67 %31 +%69 = OpLogicalAnd %33 %68 %65 +%70 = OpLogicalNot %33 %69 +%72 = OpBitwiseAnd %6 %42 %71 +%73 = OpINotEqual %33 %72 %31 +%75 = OpBitwiseAnd %6 %42 %74 +%76 = OpINotEqual %33 %75 %31 +%77 = OpLogicalAnd %33 %76 %65 +%78 = OpLogicalAnd %33 %76 %73 +%79 = OpLogicalAnd %33 %73 %65 +%80 = OpLogicalOr %33 %79 %77 +%81 = OpLogicalOr %33 %80 %78 +%82 = OpLogicalNot %33 %81 +%84 = OpBitwiseAnd %6 %42 %83 +%85 = OpINotEqual %33 %84 %31 +%87 = OpBitwiseAnd %6 %42 %86 +%88 = OpINotEqual %33 %87 %31 +%90 = OpBitwiseAnd %6 %42 %89 +%91 = OpINotEqual %33 %90 %31 +%93 = OpBitwiseAnd %6 %42 %92 +%94 = OpINotEqual %33 %93 %31 +%95 = OpLogicalAnd %33 %94 %85 +%96 = OpLogicalAnd %33 %94 %88 +%97 = OpLogicalAnd %33 %94 %91 +%98 = OpLogicalAnd %33 %91 %85 +%99 = OpLogicalAnd %33 %91 %88 +%100 = OpLogicalAnd %33 %88 %85 +%101 = OpLogicalOr %33 %100 %95 +%102 = OpLogicalOr %33 %101 %96 +%103 = OpLogicalOr %33 %102 %97 +%104 = OpLogicalOr %33 %103 %98 +%105 = OpLogicalOr %33 %104 %99 +%106 = OpLogicalNot %33 %105 +%107 = OpLogicalAnd %33 %48 %49 +%108 = OpLogicalAnd %33 %56 %62 +%109 = 
OpLogicalAnd %33 %70 %82 +%110 = OpLogicalAnd %33 %109 %106 +%111 = OpLogicalAnd %33 %107 %108 +%112 = OpLogicalAnd %33 %111 %110 +OpSelectionMerge %113 None +OpBranchConditional %112 %115 %114 +%115 = OpLabel +OpRayQueryInitializeKHR %37 %38 %42 %43 %46 %44 %47 %45 +OpStore %40 %83 +OpBranch %113 +%114 = OpLabel +OpBranch %113 +%113 = OpLabel +OpReturn +OpFunctionEnd +%141 = OpFunction %33 None %140 +%142 = OpFunctionParameter %28 +%143 = OpFunctionParameter %29 +%144 = OpLabel +%145 = OpVariable %139 Function %146 +%147 = OpLoad %6 %143 +%150 = OpBitwiseAnd %6 %147 %83 +%151 = OpINotEqual %33 %150 %31 +OpSelectionMerge %148 None +OpBranchConditional %151 %149 %148 +%149 = OpLabel +%152 = OpRayQueryProceedKHR %33 %142 +OpStore %145 %152 +%154 = OpSelect %6 %152 %86 %153 +%155 = OpBitwiseOr %6 %147 %154 +OpStore %143 %155 +OpBranch %148 +%148 = OpLabel +%156 = OpLoad %33 %145 +OpReturnValue %156 +OpFunctionEnd %13 = OpFunction %2 None %14 %12 = OpLabel %27 = OpVariable %28 Function -%49 = OpVariable %41 Function %48 +%30 = OpVariable %29 Function %31 +%127 = OpVariable %122 Function %126 %15 = OpLoad %4 %10 -OpBranch %29 -%29 = OpLabel -%30 = OpCompositeExtract %6 %26 0 -%31 = OpCompositeExtract %6 %26 1 -%32 = OpCompositeExtract %3 %26 2 -%33 = OpCompositeExtract %3 %26 3 -%34 = OpCompositeExtract %7 %26 4 -%35 = OpCompositeExtract %7 %26 5 -OpRayQueryInitializeKHR %27 %15 %30 %31 %34 %32 %35 %33 -OpBranch %36 -%36 = OpLabel -OpLoopMerge %37 %39 None -OpBranch %50 -%50 = OpLabel -%51 = OpLoad %40 %49 -%52 = OpIEqual %43 %45 %51 -%53 = OpAll %42 %52 -OpSelectionMerge %54 None -OpBranchConditional %53 %37 %54 -%54 = OpLabel -%55 = OpCompositeExtract %6 %51 1 -%56 = OpIEqual %42 %55 %44 -%57 = OpSelect %6 %56 %46 %44 -%58 = OpCompositeConstruct %40 %57 %46 -%59 = OpISub %40 %51 %58 -OpStore %49 %59 -OpBranch %38 -%38 = OpLabel -%60 = OpRayQueryProceedKHR %42 %27 -OpSelectionMerge %61 None -OpBranchConditional %60 %61 %62 -%62 = OpLabel -OpBranch %37 -%61 = OpLabel 
-OpBranch %63 -%63 = OpLabel -OpBranch %64 -%64 = OpLabel -OpBranch %39 -%39 = OpLabel -OpBranch %36 -%37 = OpLabel +OpBranch %32 +%32 = OpLabel +%116 = OpFunctionCall %2 %36 %27 %15 %26 %30 +OpBranch %117 +%117 = OpLabel +OpLoopMerge %118 %120 None +OpBranch %128 +%128 = OpLabel +%129 = OpLoad %121 %127 +%130 = OpIEqual %123 %124 %129 +%131 = OpAll %33 %130 +OpSelectionMerge %132 None +OpBranchConditional %131 %118 %132 +%132 = OpLabel +%133 = OpCompositeExtract %6 %129 1 +%134 = OpIEqual %33 %133 %31 +%135 = OpSelect %6 %134 %83 %31 +%136 = OpCompositeConstruct %121 %135 %83 +%137 = OpISub %121 %129 %136 +OpStore %127 %137 +OpBranch %119 +%119 = OpLabel +%138 = OpFunctionCall %33 %141 %27 %30 +OpSelectionMerge %157 None +OpBranchConditional %138 %157 %158 +%158 = OpLabel +OpBranch %118 +%157 = OpLabel +OpBranch %159 +%159 = OpLabel +OpBranch %160 +%160 = OpLabel +OpBranch %120 +%120 = OpLabel +OpBranch %117 +%118 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-ray-query-no-init-tracking.spvasm b/naga/tests/out/spv/wgsl-ray-query-no-init-tracking.spvasm new file mode 100644 index 00000000000..1e82468bf0a --- /dev/null +++ b/naga/tests/out/spv/wgsl-ray-query-no-init-tracking.spvasm @@ -0,0 +1,516 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 363 +OpCapability Shader +OpCapability RayQueryKHR +OpExtension "SPV_KHR_ray_query" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %240 "main" %15 %17 +OpEntryPoint GLCompute %260 "main_candidate" %15 +OpExecutionMode %240 LocalSize 1 1 1 +OpExecutionMode %260 LocalSize 1 1 1 +OpMemberDecorate %10 0 Offset 0 +OpMemberDecorate %10 1 Offset 4 +OpMemberDecorate %10 2 Offset 8 +OpMemberDecorate %10 3 Offset 12 +OpMemberDecorate %10 4 Offset 16 +OpMemberDecorate %10 5 Offset 20 +OpMemberDecorate %10 6 Offset 24 +OpMemberDecorate %10 7 Offset 28 +OpMemberDecorate %10 8 Offset 36 +OpMemberDecorate %10 9 Offset 48 
+OpMemberDecorate %10 9 ColMajor +OpMemberDecorate %10 9 MatrixStride 16 +OpMemberDecorate %10 10 Offset 112 +OpMemberDecorate %10 10 ColMajor +OpMemberDecorate %10 10 MatrixStride 16 +OpMemberDecorate %12 0 Offset 0 +OpMemberDecorate %12 1 Offset 4 +OpMemberDecorate %12 2 Offset 8 +OpMemberDecorate %12 3 Offset 12 +OpMemberDecorate %12 4 Offset 16 +OpMemberDecorate %12 5 Offset 32 +OpMemberDecorate %13 0 Offset 0 +OpMemberDecorate %13 1 Offset 16 +OpDecorate %15 DescriptorSet 0 +OpDecorate %15 Binding 0 +OpDecorate %17 DescriptorSet 0 +OpDecorate %17 Binding 1 +OpDecorate %18 Block +OpMemberDecorate %18 0 Offset 0 +%2 = OpTypeVoid +%3 = OpTypeFloat 32 +%4 = OpTypeVector %3 3 +%5 = OpTypeAccelerationStructureKHR +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %3 2 +%8 = OpTypeBool +%9 = OpTypeMatrix %4 4 +%10 = OpTypeStruct %6 %3 %6 %6 %6 %6 %6 %7 %8 %9 %9 +%11 = OpTypeRayQueryKHR +%12 = OpTypeStruct %6 %6 %3 %3 %4 %4 +%13 = OpTypeStruct %6 %4 +%14 = OpTypeVector %3 4 +%16 = OpTypePointer UniformConstant %5 +%15 = OpVariable %16 UniformConstant +%18 = OpTypeStruct %13 +%19 = OpTypePointer StorageBuffer %18 +%17 = OpVariable %19 StorageBuffer +%26 = OpTypeFunction %10 %4 %4 %16 +%27 = OpConstant %6 4 +%28 = OpConstant %6 255 +%29 = OpConstant %3 0.1 +%30 = OpConstant %3 100 +%32 = OpTypePointer Function %11 +%33 = OpTypePointer Function %6 +%35 = OpConstant %6 0 +%38 = OpTypeVector %8 3 +%39 = OpTypeFunction %2 %32 %5 %12 %33 +%54 = OpConstant %3 0 +%67 = OpConstant %6 256 +%70 = OpConstant %6 512 +%75 = OpConstant %6 16 +%78 = OpConstant %6 32 +%87 = OpConstant %6 1 +%90 = OpConstant %6 2 +%93 = OpConstant %6 64 +%96 = OpConstant %6 128 +%125 = OpTypeVector %6 2 +%126 = OpTypePointer Function %125 +%127 = OpTypeVector %8 2 +%128 = OpConstantComposite %125 %35 %35 +%129 = OpConstant %6 4294967295 +%130 = OpConstantComposite %125 %129 %129 +%143 = OpTypePointer Function %8 +%144 = OpTypeFunction %8 %32 %33 +%150 = OpConstantFalse %8 +%157 = OpConstant %6 6 +%165 = 
OpTypePointer Function %10 +%166 = OpTypePointer Function %9 +%167 = OpTypePointer Function %7 +%168 = OpTypePointer Function %3 +%169 = OpTypeFunction %10 %32 %33 +%171 = OpConstantNull %10 +%197 = OpConstant %6 3 +%200 = OpConstant %6 5 +%203 = OpConstant %6 9 +%205 = OpConstant %6 10 +%214 = OpConstant %6 7 +%216 = OpConstant %6 8 +%224 = OpTypeFunction %4 %4 %10 +%225 = OpConstant %3 1 +%226 = OpConstant %3 2.4 +%241 = OpTypeFunction %2 +%243 = OpTypePointer StorageBuffer %13 +%245 = OpConstantComposite %4 %54 %54 %54 +%246 = OpConstantComposite %4 %54 %225 %54 +%249 = OpTypePointer StorageBuffer %6 +%254 = OpTypePointer StorageBuffer %4 +%262 = OpConstantComposite %12 %27 %28 %29 %30 %245 %246 +%263 = OpConstant %3 10 +%319 = OpTypeFunction %2 %32 %33 %3 +%344 = OpTypeFunction %2 %32 %33 +%40 = OpFunction %2 None %39 +%41 = OpFunctionParameter %32 +%42 = OpFunctionParameter %5 +%43 = OpFunctionParameter %12 +%44 = OpFunctionParameter %33 +%45 = OpLabel +%46 = OpCompositeExtract %6 %43 0 +%47 = OpCompositeExtract %6 %43 1 +%48 = OpCompositeExtract %3 %43 2 +%49 = OpCompositeExtract %3 %43 3 +%50 = OpCompositeExtract %4 %43 4 +%51 = OpCompositeExtract %4 %43 5 +%52 = OpFOrdLessThanEqual %8 %48 %49 +%53 = OpFOrdGreaterThanEqual %8 %48 %54 +%55 = OpIsInf %38 %50 +%56 = OpAny %8 %55 +%57 = OpIsNan %38 %50 +%58 = OpAny %8 %57 +%59 = OpLogicalOr %8 %58 %56 +%60 = OpLogicalNot %8 %59 +%61 = OpIsInf %38 %51 +%62 = OpAny %8 %61 +%63 = OpIsNan %38 %51 +%64 = OpAny %8 %63 +%65 = OpLogicalOr %8 %64 %62 +%66 = OpLogicalNot %8 %65 +%68 = OpBitwiseAnd %6 %46 %67 +%69 = OpINotEqual %8 %68 %35 +%71 = OpBitwiseAnd %6 %46 %70 +%72 = OpINotEqual %8 %71 %35 +%73 = OpLogicalAnd %8 %72 %69 +%74 = OpLogicalNot %8 %73 +%76 = OpBitwiseAnd %6 %46 %75 +%77 = OpINotEqual %8 %76 %35 +%79 = OpBitwiseAnd %6 %46 %78 +%80 = OpINotEqual %8 %79 %35 +%81 = OpLogicalAnd %8 %80 %69 +%82 = OpLogicalAnd %8 %80 %77 +%83 = OpLogicalAnd %8 %77 %69 +%84 = OpLogicalOr %8 %83 %81 +%85 = OpLogicalOr %8 %84 
%82 +%86 = OpLogicalNot %8 %85 +%88 = OpBitwiseAnd %6 %46 %87 +%89 = OpINotEqual %8 %88 %35 +%91 = OpBitwiseAnd %6 %46 %90 +%92 = OpINotEqual %8 %91 %35 +%94 = OpBitwiseAnd %6 %46 %93 +%95 = OpINotEqual %8 %94 %35 +%97 = OpBitwiseAnd %6 %46 %96 +%98 = OpINotEqual %8 %97 %35 +%99 = OpLogicalAnd %8 %98 %89 +%100 = OpLogicalAnd %8 %98 %92 +%101 = OpLogicalAnd %8 %98 %95 +%102 = OpLogicalAnd %8 %95 %89 +%103 = OpLogicalAnd %8 %95 %92 +%104 = OpLogicalAnd %8 %92 %89 +%105 = OpLogicalOr %8 %104 %99 +%106 = OpLogicalOr %8 %105 %100 +%107 = OpLogicalOr %8 %106 %101 +%108 = OpLogicalOr %8 %107 %102 +%109 = OpLogicalOr %8 %108 %103 +%110 = OpLogicalNot %8 %109 +%111 = OpLogicalAnd %8 %52 %53 +%112 = OpLogicalAnd %8 %60 %66 +%113 = OpLogicalAnd %8 %74 %86 +%114 = OpLogicalAnd %8 %113 %110 +%115 = OpLogicalAnd %8 %111 %112 +%116 = OpLogicalAnd %8 %115 %114 +OpSelectionMerge %117 None +OpBranchConditional %116 %119 %118 +%119 = OpLabel +OpRayQueryInitializeKHR %41 %42 %46 %47 %50 %48 %51 %49 +OpStore %44 %87 +OpBranch %117 +%118 = OpLabel +OpBranch %117 +%117 = OpLabel +OpReturn +OpFunctionEnd +%145 = OpFunction %8 None %144 +%146 = OpFunctionParameter %32 +%147 = OpFunctionParameter %33 +%148 = OpLabel +%149 = OpVariable %143 Function %150 +%151 = OpLoad %6 %147 +%154 = OpBitwiseAnd %6 %151 %87 +%155 = OpINotEqual %8 %154 %35 +OpSelectionMerge %152 None +OpBranchConditional %155 %153 %152 +%153 = OpLabel +%156 = OpRayQueryProceedKHR %8 %146 +OpStore %149 %156 +%158 = OpSelect %6 %156 %90 %157 +%159 = OpBitwiseOr %6 %151 %158 +OpStore %147 %159 +OpBranch %152 +%152 = OpLabel +%160 = OpLoad %8 %149 +OpReturnValue %160 +OpFunctionEnd +%170 = OpFunction %10 None %169 +%172 = OpFunctionParameter %32 +%173 = OpFunctionParameter %33 +%174 = OpLabel +%175 = OpVariable %165 Function %171 +%176 = OpLoad %6 %173 +%177 = OpBitwiseAnd %6 %176 %90 +%178 = OpINotEqual %8 %177 %35 +%179 = OpBitwiseAnd %6 %176 %27 +%180 = OpINotEqual %8 %179 %35 +%181 = OpLogicalAnd %8 %180 %178 
+OpSelectionMerge %183 None +OpBranchConditional %181 %182 %183 +%182 = OpLabel +%184 = OpRayQueryGetIntersectionTypeKHR %6 %172 %87 +%185 = OpAccessChain %33 %175 %35 +OpStore %185 %184 +%186 = OpINotEqual %8 %184 %35 +OpSelectionMerge %188 None +OpBranchConditional %186 %187 %188 +%187 = OpLabel +%189 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %172 %87 +%190 = OpRayQueryGetIntersectionInstanceIdKHR %6 %172 %87 +%191 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %172 %87 +%192 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %172 %87 +%193 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %172 %87 +%194 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %172 %87 +%195 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %172 %87 +%196 = OpAccessChain %33 %175 %90 +OpStore %196 %189 +%198 = OpAccessChain %33 %175 %197 +OpStore %198 %190 +%199 = OpAccessChain %33 %175 %27 +OpStore %199 %191 +%201 = OpAccessChain %33 %175 %200 +OpStore %201 %192 +%202 = OpAccessChain %33 %175 %157 +OpStore %202 %193 +%204 = OpAccessChain %166 %175 %203 +OpStore %204 %194 +%206 = OpAccessChain %166 %175 %205 +OpStore %206 %195 +%207 = OpIEqual %8 %184 %87 +%210 = OpRayQueryGetIntersectionTKHR %3 %172 %87 +%211 = OpAccessChain %168 %175 %87 +OpStore %211 %210 +OpSelectionMerge %209 None +OpBranchConditional %186 %208 %209 +%208 = OpLabel +%212 = OpRayQueryGetIntersectionBarycentricsKHR %7 %172 %87 +%213 = OpRayQueryGetIntersectionFrontFaceKHR %8 %172 %87 +%215 = OpAccessChain %167 %175 %214 +OpStore %215 %212 +%217 = OpAccessChain %143 %175 %216 +OpStore %217 %213 +OpBranch %209 +%209 = OpLabel +OpBranch %188 +%188 = OpLabel +OpBranch %183 +%183 = OpLabel +%218 = OpLoad %10 %175 +OpReturnValue %218 +OpFunctionEnd +%25 = OpFunction %10 None %26 +%21 = OpFunctionParameter %4 +%22 = OpFunctionParameter %4 +%23 = OpFunctionParameter %16 +%20 = OpLabel +%31 = OpVariable %32 Function +%34 = OpVariable %33 Function %35 +%131 = OpVariable %126 Function %130 +%24 = OpLoad 
%5 %23 +OpBranch %36 +%36 = OpLabel +%37 = OpCompositeConstruct %12 %27 %28 %29 %30 %21 %22 +%120 = OpFunctionCall %2 %40 %31 %24 %37 %34 +OpBranch %121 +%121 = OpLabel +OpLoopMerge %122 %124 None +OpBranch %132 +%132 = OpLabel +%133 = OpLoad %125 %131 +%134 = OpIEqual %127 %128 %133 +%135 = OpAll %8 %134 +OpSelectionMerge %136 None +OpBranchConditional %135 %122 %136 +%136 = OpLabel +%137 = OpCompositeExtract %6 %133 1 +%138 = OpIEqual %8 %137 %35 +%139 = OpSelect %6 %138 %87 %35 +%140 = OpCompositeConstruct %125 %139 %87 +%141 = OpISub %125 %133 %140 +OpStore %131 %141 +OpBranch %123 +%123 = OpLabel +%142 = OpFunctionCall %8 %145 %31 %34 +OpSelectionMerge %161 None +OpBranchConditional %142 %161 %162 +%162 = OpLabel +OpBranch %122 +%161 = OpLabel +OpBranch %163 +%163 = OpLabel +OpBranch %164 +%164 = OpLabel +OpBranch %124 +%124 = OpLabel +OpBranch %121 +%122 = OpLabel +%219 = OpFunctionCall %10 %170 %31 %34 +OpReturnValue %219 +OpFunctionEnd +%223 = OpFunction %4 None %224 +%221 = OpFunctionParameter %4 +%222 = OpFunctionParameter %10 +%220 = OpLabel +OpBranch %227 +%227 = OpLabel +%228 = OpCompositeExtract %9 %222 10 +%229 = OpCompositeConstruct %14 %221 %225 +%230 = OpMatrixTimesVector %4 %228 %229 +%231 = OpVectorShuffle %7 %230 %230 0 1 +%232 = OpExtInst %7 %1 Normalize %231 +%233 = OpVectorTimesScalar %7 %232 %226 +%234 = OpCompositeExtract %9 %222 9 +%235 = OpCompositeConstruct %14 %233 %54 %225 +%236 = OpMatrixTimesVector %4 %234 %235 +%237 = OpFSub %4 %221 %236 +%238 = OpExtInst %4 %1 Normalize %237 +OpReturnValue %238 +OpFunctionEnd +%240 = OpFunction %2 None %241 +%239 = OpLabel +%242 = OpLoad %5 %15 +%244 = OpAccessChain %243 %17 %35 +OpBranch %247 +%247 = OpLabel +%248 = OpFunctionCall %10 %25 %245 %246 %15 +%250 = OpCompositeExtract %6 %248 0 +%251 = OpIEqual %8 %250 %35 +%252 = OpSelect %6 %251 %87 %35 +%253 = OpAccessChain %249 %244 %35 +OpStore %253 %252 +%255 = OpCompositeExtract %3 %248 1 +%256 = OpVectorTimesScalar %4 %246 %255 +%257 = 
OpFunctionCall %4 %223 %256 %248 +%258 = OpAccessChain %254 %244 %87 +OpStore %258 %257 +OpReturn +OpFunctionEnd +%268 = OpFunction %10 None %169 +%269 = OpFunctionParameter %32 +%270 = OpFunctionParameter %33 +%271 = OpLabel +%272 = OpVariable %165 Function %171 +%273 = OpLoad %6 %270 +%274 = OpBitwiseAnd %6 %273 %90 +%275 = OpINotEqual %8 %274 %35 +%276 = OpBitwiseAnd %6 %273 %27 +%277 = OpINotEqual %8 %276 %35 +%278 = OpLogicalNot %8 %277 +%279 = OpLogicalAnd %8 %278 %275 +OpSelectionMerge %281 None +OpBranchConditional %279 %280 %281 +%280 = OpLabel +%282 = OpRayQueryGetIntersectionTypeKHR %6 %269 %35 +%283 = OpIEqual %8 %282 %35 +%284 = OpSelect %6 %283 %87 %197 +%285 = OpAccessChain %33 %272 %35 +OpStore %285 %284 +%286 = OpINotEqual %8 %284 %35 +OpSelectionMerge %288 None +OpBranchConditional %286 %287 %288 +%287 = OpLabel +%289 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %269 %35 +%290 = OpRayQueryGetIntersectionInstanceIdKHR %6 %269 %35 +%291 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %269 %35 +%292 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %269 %35 +%293 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %269 %35 +%294 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %269 %35 +%295 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %269 %35 +%296 = OpAccessChain %33 %272 %90 +OpStore %296 %289 +%297 = OpAccessChain %33 %272 %197 +OpStore %297 %290 +%298 = OpAccessChain %33 %272 %27 +OpStore %298 %291 +%299 = OpAccessChain %33 %272 %200 +OpStore %299 %292 +%300 = OpAccessChain %33 %272 %157 +OpStore %300 %293 +%301 = OpAccessChain %166 %272 %203 +OpStore %301 %294 +%302 = OpAccessChain %166 %272 %205 +OpStore %302 %295 +%303 = OpIEqual %8 %284 %87 +OpSelectionMerge %305 None +OpBranchConditional %286 %304 %305 +%304 = OpLabel +%306 = OpRayQueryGetIntersectionTKHR %3 %269 %35 +%307 = OpAccessChain %168 %272 %87 +OpStore %307 %306 +%308 = OpRayQueryGetIntersectionBarycentricsKHR %7 %269 %35 +%309 = 
OpRayQueryGetIntersectionFrontFaceKHR %8 %269 %35 +%310 = OpAccessChain %167 %272 %214 +OpStore %310 %308 +%311 = OpAccessChain %143 %272 %216 +OpStore %311 %309 +OpBranch %305 +%305 = OpLabel +OpBranch %288 +%288 = OpLabel +OpBranch %281 +%281 = OpLabel +%312 = OpLoad %10 %272 +OpReturnValue %312 +OpFunctionEnd +%320 = OpFunction %2 None %319 +%321 = OpFunctionParameter %32 +%322 = OpFunctionParameter %33 +%323 = OpFunctionParameter %3 +%324 = OpLabel +%327 = OpLoad %6 %322 +%328 = OpBitwiseAnd %6 %327 %90 +%329 = OpINotEqual %8 %328 %35 +%330 = OpBitwiseAnd %6 %327 %27 +%331 = OpINotEqual %8 %330 %35 +%332 = OpLogicalNot %8 %331 +%333 = OpLogicalAnd %8 %332 %329 +OpSelectionMerge %326 None +OpBranchConditional %333 %325 %326 +%325 = OpLabel +%334 = OpRayQueryGetIntersectionTypeKHR %6 %321 %35 +%335 = OpIEqual %8 %334 %87 +OpSelectionMerge %337 None +OpBranchConditional %335 %336 %337 +%336 = OpLabel +OpRayQueryGenerateIntersectionKHR %321 %323 +OpBranch %337 +%337 = OpLabel +OpBranch %326 +%326 = OpLabel +OpReturn +OpFunctionEnd +%345 = OpFunction %2 None %344 +%346 = OpFunctionParameter %32 +%347 = OpFunctionParameter %33 +%348 = OpLabel +%351 = OpLoad %6 %347 +%352 = OpBitwiseAnd %6 %351 %90 +%353 = OpINotEqual %8 %352 %35 +%354 = OpBitwiseAnd %6 %351 %27 +%355 = OpINotEqual %8 %354 %35 +%356 = OpLogicalNot %8 %355 +%357 = OpLogicalAnd %8 %356 %353 +OpSelectionMerge %350 None +OpBranchConditional %357 %349 %350 +%349 = OpLabel +%358 = OpRayQueryGetIntersectionTypeKHR %6 %346 %35 +%359 = OpIEqual %8 %358 %35 +OpSelectionMerge %361 None +OpBranchConditional %359 %360 %361 +%360 = OpLabel +OpRayQueryConfirmIntersectionKHR %346 +OpBranch %361 +%361 = OpLabel +OpBranch %350 +%350 = OpLabel +OpReturn +OpFunctionEnd +%260 = OpFunction %2 None %241 +%259 = OpLabel +%264 = OpVariable %32 Function +%265 = OpVariable %33 Function %35 +%261 = OpLoad %5 %15 +OpBranch %266 +%266 = OpLabel +%267 = OpFunctionCall %2 %40 %264 %261 %262 %265 +%313 = OpFunctionCall %10 %268 %264 
%265 +%314 = OpCompositeExtract %6 %313 0 +%315 = OpIEqual %8 %314 %197 +OpSelectionMerge %316 None +OpBranchConditional %315 %317 %318 +%317 = OpLabel +%338 = OpFunctionCall %2 %320 %264 %265 %263 +OpReturn +%318 = OpLabel +%339 = OpCompositeExtract %6 %313 0 +%340 = OpIEqual %8 %339 %87 +OpSelectionMerge %341 None +OpBranchConditional %340 %342 %343 +%342 = OpLabel +%362 = OpFunctionCall %2 %345 %264 %265 +OpReturn +%343 = OpLabel +OpReturn +%341 = OpLabel +OpBranch %316 +%316 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-ray-query.spvasm b/naga/tests/out/spv/wgsl-ray-query.spvasm index d49ae40b2f8..1e82468bf0a 100644 --- a/naga/tests/out/spv/wgsl-ray-query.spvasm +++ b/naga/tests/out/spv/wgsl-ray-query.spvasm @@ -1,16 +1,16 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 218 +; Bound: 363 OpCapability Shader OpCapability RayQueryKHR OpExtension "SPV_KHR_ray_query" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %140 "main" %15 %17 -OpEntryPoint GLCompute %160 "main_candidate" %15 -OpExecutionMode %140 LocalSize 1 1 1 -OpExecutionMode %160 LocalSize 1 1 1 +OpEntryPoint GLCompute %240 "main" %15 %17 +OpEntryPoint GLCompute %260 "main_candidate" %15 +OpExecutionMode %240 LocalSize 1 1 1 +OpExecutionMode %260 LocalSize 1 1 1 OpMemberDecorate %10 0 Offset 0 OpMemberDecorate %10 1 Offset 4 OpMemberDecorate %10 2 Offset 8 @@ -64,93 +64,217 @@ OpMemberDecorate %18 0 Offset 0 %29 = OpConstant %3 0.1 %30 = OpConstant %3 100 %32 = OpTypePointer Function %11 -%45 = OpTypeVector %6 2 -%46 = OpTypePointer Function %45 -%47 = OpTypeVector %8 2 -%48 = OpConstant %6 0 -%49 = OpConstantComposite %45 %48 %48 -%50 = OpConstant %6 1 -%51 = OpConstant %6 4294967295 -%52 = OpConstantComposite %45 %51 %51 -%69 = OpTypePointer Function %10 -%70 = OpTypePointer Function %6 -%71 = OpTypePointer Function %9 -%72 = OpTypePointer Function %7 -%73 = OpTypePointer Function %8 -%74 = 
OpTypePointer Function %3 -%75 = OpTypeFunction %10 %32 -%77 = OpConstantNull %10 -%93 = OpConstant %6 2 -%95 = OpConstant %6 3 -%98 = OpConstant %6 5 -%100 = OpConstant %6 6 -%102 = OpConstant %6 9 -%104 = OpConstant %6 10 -%113 = OpConstant %6 7 -%115 = OpConstant %6 8 -%123 = OpTypeFunction %4 %4 %10 -%124 = OpConstant %3 1 -%125 = OpConstant %3 2.4 -%126 = OpConstant %3 0 -%141 = OpTypeFunction %2 -%143 = OpTypePointer StorageBuffer %13 -%145 = OpConstantComposite %4 %126 %126 %126 -%146 = OpConstantComposite %4 %126 %124 %126 -%149 = OpTypePointer StorageBuffer %6 -%154 = OpTypePointer StorageBuffer %4 -%162 = OpConstantComposite %12 %27 %28 %29 %30 %145 %146 -%163 = OpConstant %3 10 -%76 = OpFunction %10 None %75 -%78 = OpFunctionParameter %32 -%79 = OpLabel -%80 = OpVariable %69 Function %77 -%81 = OpRayQueryGetIntersectionTypeKHR %6 %78 %50 -%82 = OpAccessChain %70 %80 %48 -OpStore %82 %81 -%83 = OpINotEqual %8 %81 %48 -OpSelectionMerge %85 None -OpBranchConditional %83 %84 %85 -%84 = OpLabel -%86 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %78 %50 -%87 = OpRayQueryGetIntersectionInstanceIdKHR %6 %78 %50 -%88 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %78 %50 -%89 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %78 %50 -%90 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %78 %50 -%91 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %78 %50 -%92 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %78 %50 -%94 = OpAccessChain %70 %80 %93 -OpStore %94 %86 -%96 = OpAccessChain %70 %80 %95 -OpStore %96 %87 -%97 = OpAccessChain %70 %80 %27 -OpStore %97 %88 -%99 = OpAccessChain %70 %80 %98 -OpStore %99 %89 -%101 = OpAccessChain %70 %80 %100 -OpStore %101 %90 -%103 = OpAccessChain %71 %80 %102 -OpStore %103 %91 -%105 = OpAccessChain %71 %80 %104 -OpStore %105 %92 -%106 = OpIEqual %8 %81 %50 -%109 = OpRayQueryGetIntersectionTKHR %3 %78 %50 -%110 = OpAccessChain %74 %80 %50 -OpStore %110 %109 -OpSelectionMerge %108 None 
-OpBranchConditional %83 %107 %108 -%107 = OpLabel -%111 = OpRayQueryGetIntersectionBarycentricsKHR %7 %78 %50 -%112 = OpRayQueryGetIntersectionFrontFaceKHR %8 %78 %50 -%114 = OpAccessChain %72 %80 %113 -OpStore %114 %111 -%116 = OpAccessChain %73 %80 %115 -OpStore %116 %112 -OpBranch %108 -%108 = OpLabel -OpBranch %85 -%85 = OpLabel -%117 = OpLoad %10 %80 -OpReturnValue %117 +%33 = OpTypePointer Function %6 +%35 = OpConstant %6 0 +%38 = OpTypeVector %8 3 +%39 = OpTypeFunction %2 %32 %5 %12 %33 +%54 = OpConstant %3 0 +%67 = OpConstant %6 256 +%70 = OpConstant %6 512 +%75 = OpConstant %6 16 +%78 = OpConstant %6 32 +%87 = OpConstant %6 1 +%90 = OpConstant %6 2 +%93 = OpConstant %6 64 +%96 = OpConstant %6 128 +%125 = OpTypeVector %6 2 +%126 = OpTypePointer Function %125 +%127 = OpTypeVector %8 2 +%128 = OpConstantComposite %125 %35 %35 +%129 = OpConstant %6 4294967295 +%130 = OpConstantComposite %125 %129 %129 +%143 = OpTypePointer Function %8 +%144 = OpTypeFunction %8 %32 %33 +%150 = OpConstantFalse %8 +%157 = OpConstant %6 6 +%165 = OpTypePointer Function %10 +%166 = OpTypePointer Function %9 +%167 = OpTypePointer Function %7 +%168 = OpTypePointer Function %3 +%169 = OpTypeFunction %10 %32 %33 +%171 = OpConstantNull %10 +%197 = OpConstant %6 3 +%200 = OpConstant %6 5 +%203 = OpConstant %6 9 +%205 = OpConstant %6 10 +%214 = OpConstant %6 7 +%216 = OpConstant %6 8 +%224 = OpTypeFunction %4 %4 %10 +%225 = OpConstant %3 1 +%226 = OpConstant %3 2.4 +%241 = OpTypeFunction %2 +%243 = OpTypePointer StorageBuffer %13 +%245 = OpConstantComposite %4 %54 %54 %54 +%246 = OpConstantComposite %4 %54 %225 %54 +%249 = OpTypePointer StorageBuffer %6 +%254 = OpTypePointer StorageBuffer %4 +%262 = OpConstantComposite %12 %27 %28 %29 %30 %245 %246 +%263 = OpConstant %3 10 +%319 = OpTypeFunction %2 %32 %33 %3 +%344 = OpTypeFunction %2 %32 %33 +%40 = OpFunction %2 None %39 +%41 = OpFunctionParameter %32 +%42 = OpFunctionParameter %5 +%43 = OpFunctionParameter %12 +%44 = 
OpFunctionParameter %33 +%45 = OpLabel +%46 = OpCompositeExtract %6 %43 0 +%47 = OpCompositeExtract %6 %43 1 +%48 = OpCompositeExtract %3 %43 2 +%49 = OpCompositeExtract %3 %43 3 +%50 = OpCompositeExtract %4 %43 4 +%51 = OpCompositeExtract %4 %43 5 +%52 = OpFOrdLessThanEqual %8 %48 %49 +%53 = OpFOrdGreaterThanEqual %8 %48 %54 +%55 = OpIsInf %38 %50 +%56 = OpAny %8 %55 +%57 = OpIsNan %38 %50 +%58 = OpAny %8 %57 +%59 = OpLogicalOr %8 %58 %56 +%60 = OpLogicalNot %8 %59 +%61 = OpIsInf %38 %51 +%62 = OpAny %8 %61 +%63 = OpIsNan %38 %51 +%64 = OpAny %8 %63 +%65 = OpLogicalOr %8 %64 %62 +%66 = OpLogicalNot %8 %65 +%68 = OpBitwiseAnd %6 %46 %67 +%69 = OpINotEqual %8 %68 %35 +%71 = OpBitwiseAnd %6 %46 %70 +%72 = OpINotEqual %8 %71 %35 +%73 = OpLogicalAnd %8 %72 %69 +%74 = OpLogicalNot %8 %73 +%76 = OpBitwiseAnd %6 %46 %75 +%77 = OpINotEqual %8 %76 %35 +%79 = OpBitwiseAnd %6 %46 %78 +%80 = OpINotEqual %8 %79 %35 +%81 = OpLogicalAnd %8 %80 %69 +%82 = OpLogicalAnd %8 %80 %77 +%83 = OpLogicalAnd %8 %77 %69 +%84 = OpLogicalOr %8 %83 %81 +%85 = OpLogicalOr %8 %84 %82 +%86 = OpLogicalNot %8 %85 +%88 = OpBitwiseAnd %6 %46 %87 +%89 = OpINotEqual %8 %88 %35 +%91 = OpBitwiseAnd %6 %46 %90 +%92 = OpINotEqual %8 %91 %35 +%94 = OpBitwiseAnd %6 %46 %93 +%95 = OpINotEqual %8 %94 %35 +%97 = OpBitwiseAnd %6 %46 %96 +%98 = OpINotEqual %8 %97 %35 +%99 = OpLogicalAnd %8 %98 %89 +%100 = OpLogicalAnd %8 %98 %92 +%101 = OpLogicalAnd %8 %98 %95 +%102 = OpLogicalAnd %8 %95 %89 +%103 = OpLogicalAnd %8 %95 %92 +%104 = OpLogicalAnd %8 %92 %89 +%105 = OpLogicalOr %8 %104 %99 +%106 = OpLogicalOr %8 %105 %100 +%107 = OpLogicalOr %8 %106 %101 +%108 = OpLogicalOr %8 %107 %102 +%109 = OpLogicalOr %8 %108 %103 +%110 = OpLogicalNot %8 %109 +%111 = OpLogicalAnd %8 %52 %53 +%112 = OpLogicalAnd %8 %60 %66 +%113 = OpLogicalAnd %8 %74 %86 +%114 = OpLogicalAnd %8 %113 %110 +%115 = OpLogicalAnd %8 %111 %112 +%116 = OpLogicalAnd %8 %115 %114 +OpSelectionMerge %117 None +OpBranchConditional %116 %119 %118 +%119 = 
OpLabel +OpRayQueryInitializeKHR %41 %42 %46 %47 %50 %48 %51 %49 +OpStore %44 %87 +OpBranch %117 +%118 = OpLabel +OpBranch %117 +%117 = OpLabel +OpReturn +OpFunctionEnd +%145 = OpFunction %8 None %144 +%146 = OpFunctionParameter %32 +%147 = OpFunctionParameter %33 +%148 = OpLabel +%149 = OpVariable %143 Function %150 +%151 = OpLoad %6 %147 +%154 = OpBitwiseAnd %6 %151 %87 +%155 = OpINotEqual %8 %154 %35 +OpSelectionMerge %152 None +OpBranchConditional %155 %153 %152 +%153 = OpLabel +%156 = OpRayQueryProceedKHR %8 %146 +OpStore %149 %156 +%158 = OpSelect %6 %156 %90 %157 +%159 = OpBitwiseOr %6 %151 %158 +OpStore %147 %159 +OpBranch %152 +%152 = OpLabel +%160 = OpLoad %8 %149 +OpReturnValue %160 +OpFunctionEnd +%170 = OpFunction %10 None %169 +%172 = OpFunctionParameter %32 +%173 = OpFunctionParameter %33 +%174 = OpLabel +%175 = OpVariable %165 Function %171 +%176 = OpLoad %6 %173 +%177 = OpBitwiseAnd %6 %176 %90 +%178 = OpINotEqual %8 %177 %35 +%179 = OpBitwiseAnd %6 %176 %27 +%180 = OpINotEqual %8 %179 %35 +%181 = OpLogicalAnd %8 %180 %178 +OpSelectionMerge %183 None +OpBranchConditional %181 %182 %183 +%182 = OpLabel +%184 = OpRayQueryGetIntersectionTypeKHR %6 %172 %87 +%185 = OpAccessChain %33 %175 %35 +OpStore %185 %184 +%186 = OpINotEqual %8 %184 %35 +OpSelectionMerge %188 None +OpBranchConditional %186 %187 %188 +%187 = OpLabel +%189 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %172 %87 +%190 = OpRayQueryGetIntersectionInstanceIdKHR %6 %172 %87 +%191 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %172 %87 +%192 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %172 %87 +%193 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %172 %87 +%194 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %172 %87 +%195 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %172 %87 +%196 = OpAccessChain %33 %175 %90 +OpStore %196 %189 +%198 = OpAccessChain %33 %175 %197 +OpStore %198 %190 +%199 = OpAccessChain %33 %175 %27 +OpStore %199 %191 +%201 = 
OpAccessChain %33 %175 %200 +OpStore %201 %192 +%202 = OpAccessChain %33 %175 %157 +OpStore %202 %193 +%204 = OpAccessChain %166 %175 %203 +OpStore %204 %194 +%206 = OpAccessChain %166 %175 %205 +OpStore %206 %195 +%207 = OpIEqual %8 %184 %87 +%210 = OpRayQueryGetIntersectionTKHR %3 %172 %87 +%211 = OpAccessChain %168 %175 %87 +OpStore %211 %210 +OpSelectionMerge %209 None +OpBranchConditional %186 %208 %209 +%208 = OpLabel +%212 = OpRayQueryGetIntersectionBarycentricsKHR %7 %172 %87 +%213 = OpRayQueryGetIntersectionFrontFaceKHR %8 %172 %87 +%215 = OpAccessChain %167 %175 %214 +OpStore %215 %212 +%217 = OpAccessChain %143 %175 %216 +OpStore %217 %213 +OpBranch %209 +%209 = OpLabel +OpBranch %188 +%188 = OpLabel +OpBranch %183 +%183 = OpLabel +%218 = OpLoad %10 %175 +OpReturnValue %218 OpFunctionEnd %25 = OpFunction %10 None %26 %21 = OpFunctionParameter %4 @@ -158,179 +282,235 @@ OpFunctionEnd %23 = OpFunctionParameter %16 %20 = OpLabel %31 = OpVariable %32 Function -%53 = OpVariable %46 Function %52 +%34 = OpVariable %33 Function %35 +%131 = OpVariable %126 Function %130 %24 = OpLoad %5 %23 -OpBranch %33 -%33 = OpLabel -%34 = OpCompositeConstruct %12 %27 %28 %29 %30 %21 %22 -%35 = OpCompositeExtract %6 %34 0 -%36 = OpCompositeExtract %6 %34 1 -%37 = OpCompositeExtract %3 %34 2 -%38 = OpCompositeExtract %3 %34 3 -%39 = OpCompositeExtract %4 %34 4 -%40 = OpCompositeExtract %4 %34 5 -OpRayQueryInitializeKHR %31 %24 %35 %36 %39 %37 %40 %38 -OpBranch %41 -%41 = OpLabel -OpLoopMerge %42 %44 None -OpBranch %54 -%54 = OpLabel -%55 = OpLoad %45 %53 -%56 = OpIEqual %47 %49 %55 -%57 = OpAll %8 %56 -OpSelectionMerge %58 None -OpBranchConditional %57 %42 %58 -%58 = OpLabel -%59 = OpCompositeExtract %6 %55 1 -%60 = OpIEqual %8 %59 %48 -%61 = OpSelect %6 %60 %50 %48 -%62 = OpCompositeConstruct %45 %61 %50 -%63 = OpISub %45 %55 %62 -OpStore %53 %63 -OpBranch %43 -%43 = OpLabel -%64 = OpRayQueryProceedKHR %8 %31 -OpSelectionMerge %65 None -OpBranchConditional %64 %65 %66 -%66 = 
OpLabel -OpBranch %42 -%65 = OpLabel -OpBranch %67 -%67 = OpLabel -OpBranch %68 -%68 = OpLabel -OpBranch %44 -%44 = OpLabel -OpBranch %41 -%42 = OpLabel -%118 = OpFunctionCall %10 %76 %31 -OpReturnValue %118 +OpBranch %36 +%36 = OpLabel +%37 = OpCompositeConstruct %12 %27 %28 %29 %30 %21 %22 +%120 = OpFunctionCall %2 %40 %31 %24 %37 %34 +OpBranch %121 +%121 = OpLabel +OpLoopMerge %122 %124 None +OpBranch %132 +%132 = OpLabel +%133 = OpLoad %125 %131 +%134 = OpIEqual %127 %128 %133 +%135 = OpAll %8 %134 +OpSelectionMerge %136 None +OpBranchConditional %135 %122 %136 +%136 = OpLabel +%137 = OpCompositeExtract %6 %133 1 +%138 = OpIEqual %8 %137 %35 +%139 = OpSelect %6 %138 %87 %35 +%140 = OpCompositeConstruct %125 %139 %87 +%141 = OpISub %125 %133 %140 +OpStore %131 %141 +OpBranch %123 +%123 = OpLabel +%142 = OpFunctionCall %8 %145 %31 %34 +OpSelectionMerge %161 None +OpBranchConditional %142 %161 %162 +%162 = OpLabel +OpBranch %122 +%161 = OpLabel +OpBranch %163 +%163 = OpLabel +OpBranch %164 +%164 = OpLabel +OpBranch %124 +%124 = OpLabel +OpBranch %121 +%122 = OpLabel +%219 = OpFunctionCall %10 %170 %31 %34 +OpReturnValue %219 OpFunctionEnd -%122 = OpFunction %4 None %123 -%120 = OpFunctionParameter %4 -%121 = OpFunctionParameter %10 -%119 = OpLabel -OpBranch %127 -%127 = OpLabel -%128 = OpCompositeExtract %9 %121 10 -%129 = OpCompositeConstruct %14 %120 %124 -%130 = OpMatrixTimesVector %4 %128 %129 -%131 = OpVectorShuffle %7 %130 %130 0 1 -%132 = OpExtInst %7 %1 Normalize %131 -%133 = OpVectorTimesScalar %7 %132 %125 -%134 = OpCompositeExtract %9 %121 9 -%135 = OpCompositeConstruct %14 %133 %126 %124 -%136 = OpMatrixTimesVector %4 %134 %135 -%137 = OpFSub %4 %120 %136 -%138 = OpExtInst %4 %1 Normalize %137 -OpReturnValue %138 +%223 = OpFunction %4 None %224 +%221 = OpFunctionParameter %4 +%222 = OpFunctionParameter %10 +%220 = OpLabel +OpBranch %227 +%227 = OpLabel +%228 = OpCompositeExtract %9 %222 10 +%229 = OpCompositeConstruct %14 %221 %225 +%230 = 
OpMatrixTimesVector %4 %228 %229 +%231 = OpVectorShuffle %7 %230 %230 0 1 +%232 = OpExtInst %7 %1 Normalize %231 +%233 = OpVectorTimesScalar %7 %232 %226 +%234 = OpCompositeExtract %9 %222 9 +%235 = OpCompositeConstruct %14 %233 %54 %225 +%236 = OpMatrixTimesVector %4 %234 %235 +%237 = OpFSub %4 %221 %236 +%238 = OpExtInst %4 %1 Normalize %237 +OpReturnValue %238 OpFunctionEnd -%140 = OpFunction %2 None %141 -%139 = OpLabel -%142 = OpLoad %5 %15 -%144 = OpAccessChain %143 %17 %48 -OpBranch %147 -%147 = OpLabel -%148 = OpFunctionCall %10 %25 %145 %146 %15 -%150 = OpCompositeExtract %6 %148 0 -%151 = OpIEqual %8 %150 %48 -%152 = OpSelect %6 %151 %50 %48 -%153 = OpAccessChain %149 %144 %48 -OpStore %153 %152 -%155 = OpCompositeExtract %3 %148 1 -%156 = OpVectorTimesScalar %4 %146 %155 -%157 = OpFunctionCall %4 %122 %156 %148 -%158 = OpAccessChain %154 %144 %50 -OpStore %158 %157 +%240 = OpFunction %2 None %241 +%239 = OpLabel +%242 = OpLoad %5 %15 +%244 = OpAccessChain %243 %17 %35 +OpBranch %247 +%247 = OpLabel +%248 = OpFunctionCall %10 %25 %245 %246 %15 +%250 = OpCompositeExtract %6 %248 0 +%251 = OpIEqual %8 %250 %35 +%252 = OpSelect %6 %251 %87 %35 +%253 = OpAccessChain %249 %244 %35 +OpStore %253 %252 +%255 = OpCompositeExtract %3 %248 1 +%256 = OpVectorTimesScalar %4 %246 %255 +%257 = OpFunctionCall %4 %223 %256 %248 +%258 = OpAccessChain %254 %244 %87 +OpStore %258 %257 OpReturn OpFunctionEnd -%172 = OpFunction %10 None %75 -%173 = OpFunctionParameter %32 -%174 = OpLabel -%175 = OpVariable %69 Function %77 -%176 = OpRayQueryGetIntersectionTypeKHR %6 %173 %48 -%177 = OpIEqual %8 %176 %48 -%178 = OpSelect %6 %177 %50 %95 -%179 = OpAccessChain %70 %175 %48 -OpStore %179 %178 -%180 = OpINotEqual %8 %178 %48 -OpSelectionMerge %182 None -OpBranchConditional %180 %181 %182 -%181 = OpLabel -%183 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %173 %48 -%184 = OpRayQueryGetIntersectionInstanceIdKHR %6 %173 %48 -%185 = 
OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %173 %48 -%186 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %173 %48 -%187 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %173 %48 -%188 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %173 %48 -%189 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %173 %48 -%190 = OpAccessChain %70 %175 %93 -OpStore %190 %183 -%191 = OpAccessChain %70 %175 %95 -OpStore %191 %184 -%192 = OpAccessChain %70 %175 %27 -OpStore %192 %185 -%193 = OpAccessChain %70 %175 %98 -OpStore %193 %186 -%194 = OpAccessChain %70 %175 %100 -OpStore %194 %187 -%195 = OpAccessChain %71 %175 %102 -OpStore %195 %188 -%196 = OpAccessChain %71 %175 %104 -OpStore %196 %189 -%197 = OpIEqual %8 %178 %50 -OpSelectionMerge %199 None -OpBranchConditional %180 %198 %199 -%198 = OpLabel -%200 = OpRayQueryGetIntersectionTKHR %3 %173 %48 -%201 = OpAccessChain %74 %175 %50 -OpStore %201 %200 -%202 = OpRayQueryGetIntersectionBarycentricsKHR %7 %173 %48 -%203 = OpRayQueryGetIntersectionFrontFaceKHR %8 %173 %48 -%204 = OpAccessChain %72 %175 %113 -OpStore %204 %202 -%205 = OpAccessChain %73 %175 %115 -OpStore %205 %203 -OpBranch %199 -%199 = OpLabel -OpBranch %182 -%182 = OpLabel -%206 = OpLoad %10 %175 -OpReturnValue %206 +%268 = OpFunction %10 None %169 +%269 = OpFunctionParameter %32 +%270 = OpFunctionParameter %33 +%271 = OpLabel +%272 = OpVariable %165 Function %171 +%273 = OpLoad %6 %270 +%274 = OpBitwiseAnd %6 %273 %90 +%275 = OpINotEqual %8 %274 %35 +%276 = OpBitwiseAnd %6 %273 %27 +%277 = OpINotEqual %8 %276 %35 +%278 = OpLogicalNot %8 %277 +%279 = OpLogicalAnd %8 %278 %275 +OpSelectionMerge %281 None +OpBranchConditional %279 %280 %281 +%280 = OpLabel +%282 = OpRayQueryGetIntersectionTypeKHR %6 %269 %35 +%283 = OpIEqual %8 %282 %35 +%284 = OpSelect %6 %283 %87 %197 +%285 = OpAccessChain %33 %272 %35 +OpStore %285 %284 +%286 = OpINotEqual %8 %284 %35 +OpSelectionMerge %288 None +OpBranchConditional %286 %287 %288 +%287 = OpLabel +%289 = 
OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %269 %35 +%290 = OpRayQueryGetIntersectionInstanceIdKHR %6 %269 %35 +%291 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %269 %35 +%292 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %269 %35 +%293 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %269 %35 +%294 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %269 %35 +%295 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %269 %35 +%296 = OpAccessChain %33 %272 %90 +OpStore %296 %289 +%297 = OpAccessChain %33 %272 %197 +OpStore %297 %290 +%298 = OpAccessChain %33 %272 %27 +OpStore %298 %291 +%299 = OpAccessChain %33 %272 %200 +OpStore %299 %292 +%300 = OpAccessChain %33 %272 %157 +OpStore %300 %293 +%301 = OpAccessChain %166 %272 %203 +OpStore %301 %294 +%302 = OpAccessChain %166 %272 %205 +OpStore %302 %295 +%303 = OpIEqual %8 %284 %87 +OpSelectionMerge %305 None +OpBranchConditional %286 %304 %305 +%304 = OpLabel +%306 = OpRayQueryGetIntersectionTKHR %3 %269 %35 +%307 = OpAccessChain %168 %272 %87 +OpStore %307 %306 +%308 = OpRayQueryGetIntersectionBarycentricsKHR %7 %269 %35 +%309 = OpRayQueryGetIntersectionFrontFaceKHR %8 %269 %35 +%310 = OpAccessChain %167 %272 %214 +OpStore %310 %308 +%311 = OpAccessChain %143 %272 %216 +OpStore %311 %309 +OpBranch %305 +%305 = OpLabel +OpBranch %288 +%288 = OpLabel +OpBranch %281 +%281 = OpLabel +%312 = OpLoad %10 %272 +OpReturnValue %312 +OpFunctionEnd +%320 = OpFunction %2 None %319 +%321 = OpFunctionParameter %32 +%322 = OpFunctionParameter %33 +%323 = OpFunctionParameter %3 +%324 = OpLabel +%327 = OpLoad %6 %322 +%328 = OpBitwiseAnd %6 %327 %90 +%329 = OpINotEqual %8 %328 %35 +%330 = OpBitwiseAnd %6 %327 %27 +%331 = OpINotEqual %8 %330 %35 +%332 = OpLogicalNot %8 %331 +%333 = OpLogicalAnd %8 %332 %329 +OpSelectionMerge %326 None +OpBranchConditional %333 %325 %326 +%325 = OpLabel +%334 = OpRayQueryGetIntersectionTypeKHR %6 %321 %35 +%335 = OpIEqual %8 %334 %87 +OpSelectionMerge %337 None 
+OpBranchConditional %335 %336 %337 +%336 = OpLabel +OpRayQueryGenerateIntersectionKHR %321 %323 +OpBranch %337 +%337 = OpLabel +OpBranch %326 +%326 = OpLabel +OpReturn +OpFunctionEnd +%345 = OpFunction %2 None %344 +%346 = OpFunctionParameter %32 +%347 = OpFunctionParameter %33 +%348 = OpLabel +%351 = OpLoad %6 %347 +%352 = OpBitwiseAnd %6 %351 %90 +%353 = OpINotEqual %8 %352 %35 +%354 = OpBitwiseAnd %6 %351 %27 +%355 = OpINotEqual %8 %354 %35 +%356 = OpLogicalNot %8 %355 +%357 = OpLogicalAnd %8 %356 %353 +OpSelectionMerge %350 None +OpBranchConditional %357 %349 %350 +%349 = OpLabel +%358 = OpRayQueryGetIntersectionTypeKHR %6 %346 %35 +%359 = OpIEqual %8 %358 %35 +OpSelectionMerge %361 None +OpBranchConditional %359 %360 %361 +%360 = OpLabel +OpRayQueryConfirmIntersectionKHR %346 +OpBranch %361 +%361 = OpLabel +OpBranch %350 +%350 = OpLabel +OpReturn OpFunctionEnd -%160 = OpFunction %2 None %141 -%159 = OpLabel -%164 = OpVariable %32 Function -%161 = OpLoad %5 %15 -OpBranch %165 -%165 = OpLabel -%166 = OpCompositeExtract %6 %162 0 -%167 = OpCompositeExtract %6 %162 1 -%168 = OpCompositeExtract %3 %162 2 -%169 = OpCompositeExtract %3 %162 3 -%170 = OpCompositeExtract %4 %162 4 -%171 = OpCompositeExtract %4 %162 5 -OpRayQueryInitializeKHR %164 %161 %166 %167 %170 %168 %171 %169 -%207 = OpFunctionCall %10 %172 %164 -%208 = OpCompositeExtract %6 %207 0 -%209 = OpIEqual %8 %208 %95 -OpSelectionMerge %210 None -OpBranchConditional %209 %211 %212 -%211 = OpLabel -OpRayQueryGenerateIntersectionKHR %164 %163 +%260 = OpFunction %2 None %241 +%259 = OpLabel +%264 = OpVariable %32 Function +%265 = OpVariable %33 Function %35 +%261 = OpLoad %5 %15 +OpBranch %266 +%266 = OpLabel +%267 = OpFunctionCall %2 %40 %264 %261 %262 %265 +%313 = OpFunctionCall %10 %268 %264 %265 +%314 = OpCompositeExtract %6 %313 0 +%315 = OpIEqual %8 %314 %197 +OpSelectionMerge %316 None +OpBranchConditional %315 %317 %318 +%317 = OpLabel +%338 = OpFunctionCall %2 %320 %264 %265 %263 OpReturn -%212 = 
OpLabel -%213 = OpCompositeExtract %6 %207 0 -%214 = OpIEqual %8 %213 %50 -OpSelectionMerge %215 None -OpBranchConditional %214 %216 %217 -%216 = OpLabel -OpRayQueryConfirmIntersectionKHR %164 +%318 = OpLabel +%339 = OpCompositeExtract %6 %313 0 +%340 = OpIEqual %8 %339 %87 +OpSelectionMerge %341 None +OpBranchConditional %340 %342 %343 +%342 = OpLabel +%362 = OpFunctionCall %2 %345 %264 %265 OpReturn -%217 = OpLabel +%343 = OpLabel OpReturn -%215 = OpLabel -OpBranch %210 -%210 = OpLabel +%341 = OpLabel +OpBranch %316 +%316 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/tests/wgpu-gpu/ray_tracing/shader.rs b/tests/tests/wgpu-gpu/ray_tracing/shader.rs index fcd29af52e6..feb47e5b1a0 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/shader.rs +++ b/tests/tests/wgpu-gpu/ray_tracing/shader.rs @@ -1,17 +1,19 @@ use crate::ray_tracing::{acceleration_structure_limits, AsBuildContext}; +use wgpu::util::{BufferInitDescriptor, DeviceExt}; use wgpu::{ - include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferDescriptor, + include_wgsl, Backends, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferDescriptor, CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor, }; use wgpu::{AccelerationStructureFlags, BufferUsages}; use wgpu_macros::gpu_test; -use wgpu_test::GpuTestInitializer; +use wgpu_test::{FailureCase, GpuTestInitializer}; use wgpu_test::{GpuTestConfiguration, TestParameters, TestingContext}; const STRUCT_SIZE: wgpu::BufferAddress = 176; pub fn all_tests(tests: &mut Vec) { tests.push(ACCESS_ALL_STRUCT_MEMBERS); + tests.push(PREVENT_INVALID_RAY_QUERY_CALLS); } #[gpu_test] @@ -103,3 +105,95 @@ fn access_all_struct_members(ctx: TestingContext) { ctx.queue.submit([encoder_compute.finish()]); } + +#[gpu_test] +static PREVENT_INVALID_RAY_QUERY_CALLS: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .test_features_limits() + 
.limits(acceleration_structure_limits()) + .features(wgpu::Features::EXPERIMENTAL_RAY_QUERY) + // not yet implemented in directx12 + .skip(FailureCase::backend(Backends::DX12)), + ) + .run_sync(prevent_invalid_ray_query_calls); + +fn prevent_invalid_ray_query_calls(ctx: TestingContext) { + let invalid_values_buffer = ctx.device.create_buffer_init(&BufferInitDescriptor { + label: Some("invalid values buffer"), + contents: bytemuck::cast_slice(&[f32::NAN, f32::INFINITY]), + usage: BufferUsages::STORAGE, + }); + + // + // Create a clean `AsBuildContext` + // + + let as_ctx = AsBuildContext::new( + &ctx, + AccelerationStructureFlags::empty(), + AccelerationStructureFlags::empty(), + ); + + let mut encoder_build = ctx + .device + .create_command_encoder(&CommandEncoderDescriptor { + label: Some("Build"), + }); + + encoder_build.build_acceleration_structures([&as_ctx.blas_build_entry()], [&as_ctx.tlas]); + + ctx.queue.submit([encoder_build.finish()]); + + // + // Create shader + // + + let shader = ctx + .device + .create_shader_module(include_wgsl!("shader.wgsl")); + let compute_pipeline = ctx + .device + .create_compute_pipeline(&ComputePipelineDescriptor { + label: None, + layout: None, + module: &shader, + entry_point: Some("invalid_usages"), + compilation_options: Default::default(), + cache: None, + }); + + let bind_group = ctx.device.create_bind_group(&BindGroupDescriptor { + label: None, + layout: &compute_pipeline.get_bind_group_layout(0), + entries: &[ + BindGroupEntry { + binding: 0, + resource: BindingResource::AccelerationStructure(&as_ctx.tlas), + }, + BindGroupEntry { + binding: 1, + resource: BindingResource::Buffer(invalid_values_buffer.as_entire_buffer_binding()), + }, + ], + }); + + // + // Submit once to check for no issues + // + + let mut encoder_compute = ctx + .device + .create_command_encoder(&CommandEncoderDescriptor::default()); + { + let mut pass = encoder_compute.begin_compute_pass(&ComputePassDescriptor { + label: None, + timestamp_writes: 
None, + }); + pass.set_pipeline(&compute_pipeline); + pass.set_bind_group(0, Some(&bind_group), &[]); + pass.dispatch_workgroups(1, 1, 1) + } + + ctx.queue.submit([encoder_compute.finish()]); +} diff --git a/tests/tests/wgpu-gpu/ray_tracing/shader.wgsl b/tests/tests/wgpu-gpu/ray_tracing/shader.wgsl index 2130b8d9ae6..55a8f4b85d6 100644 --- a/tests/tests/wgpu-gpu/ray_tracing/shader.wgsl +++ b/tests/tests/wgpu-gpu/ray_tracing/shader.wgsl @@ -48,4 +48,78 @@ fn all_of_struct() { intersection.world_to_object, intersection.object_to_world, ); +} + +struct MaybeInvalidValues { + nan: f32, + inf: f32, +} + +@group(0) @binding(1) +var invalid_values: MaybeInvalidValues; + +@workgroup_size(1) +@compute +fn invalid_usages() { + { + var rq: ray_query; + // no initialize + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + } + { + var rq: ray_query; + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, 0.0, 1.0))); + // no proceed + let intersection = rayQueryGetCommittedIntersection(&rq); + } + { + var rq: ray_query; + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, 0.0, 1.0))); + rayQueryProceed(&rq); + // The acceleration structure has been set up to not generate any intersections, meaning it will be a committed intersection, not a candidate.
+ let intersection = rayQueryGetCandidateIntersection(&rq); + } + { + var rq: ray_query; + // NaN in origin + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, invalid_values.nan, 0.0), vec3f(0.0, 0.0, 1.0))); + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + } + { + var rq: ray_query; + // Inf in origin + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, invalid_values.inf, 0.0), vec3f(0.0, 0.0, 1.0))); + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + } + { + var rq: ray_query; + // NaN in direction + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, invalid_values.nan, 1.0))); + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + } + { + var rq: ray_query; + // Inf in direction + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, invalid_values.inf, 1.0))); + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + } + { + var rq: ray_query; + // t_min greater than t_max + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 100000.0, 0.1, vec3f(0.0, 0.0, 0.0), vec3f(0.0, 0.0, 1.0))); + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + } + { + var rq: ray_query; + // t_min less than 0 + rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, -0.001, 100000.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, 0.0, 1.0))); + rayQueryProceed(&rq); + let intersection = rayQueryGetCommittedIntersection(&rq); + } } \ No newline at end of file diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index c7e69b63c13..626b390ece1 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -290,6 +290,7 @@ impl super::Device { || stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing || 
stage.module.runtime_checks.force_loop_bounding != layout.naga_options.force_loop_bounding; + // Note: ray query initialization tracking not yet implemented let mut temp_options; let naga_options = if needs_temp_options { temp_options = layout.naga_options.clone(); diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index a01b1f7fe50..916bfe7a6e4 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -2151,6 +2151,12 @@ impl super::Adapter { // But this requires cloning the `spv::Options` struct, which has heap allocations. true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS` ); + flags.set( + spv::WriterFlags::PRINT_ON_RAY_QUERY_INITIALIZATION_FAIL, + self.instance.flags.contains(wgt::InstanceFlags::DEBUG) + && (self.instance.instance_api_version >= vk::API_VERSION_1_3 + || enabled_extensions.contains(&khr::shader_non_semantic_info::NAME)), + ); if features.contains(wgt::Features::EXPERIMENTAL_RAY_QUERY) { capabilities.push(spv::Capability::RayQueryKHR); } @@ -2206,6 +2212,7 @@ impl super::Adapter { spv::ZeroInitializeWorkgroupMemoryMode::Polyfill }, force_loop_bounding: true, + ray_query_initialization_tracking: true, use_storage_input_output_16: features.contains(wgt::Features::SHADER_F16) && self.phd_features.supports_storage_input_output_16(), fake_missing_bindings: false, diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index dbd30df1015..4e77f269b7f 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -765,6 +765,7 @@ impl super::Device { }; let needs_temp_options = !runtime_checks.bounds_checks || !runtime_checks.force_loop_bounding + || !runtime_checks.ray_query_initialization_tracking || !binding_map.is_empty() || naga_shader.debug_source.is_some() || !stage.zero_initialize_workgroup_memory; @@ -782,6 +783,9 @@ impl super::Device { if !runtime_checks.force_loop_bounding { temp_options.force_loop_bounding = false; } + if 
!runtime_checks.ray_query_initialization_tracking { + temp_options.ray_query_initialization_tracking = false; + } if !binding_map.is_empty() { temp_options.binding_map = binding_map.clone(); } diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 6f9cf415b80..b5395588fd1 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -7883,6 +7883,18 @@ pub struct ShaderRuntimeChecks { /// conclusions about other safety-critical code paths. This option SHOULD NOT be disabled /// when running untrusted code. pub force_loop_bounding: bool, + /// If false, the caller **MUST** ensure that, in all passed shaders, every ray query + /// satisfies the following (wgsl naming): `rayQueryInitialize` must have been called before `rayQueryProceed`; `rayQueryProceed` + /// must have been called, returned true and have hit an AABB before `rayQueryGenerateIntersection` + /// is called; `rayQueryProceed` must have been called, returned true and have hit + /// a triangle before `rayQueryConfirmIntersection` is called; `rayQueryProceed` + /// must have been called and have returned true before `rayQueryTerminate`, + /// `getCandidateHitVertexPositions` or `rayQueryGetCandidateIntersection` is called; + /// and `rayQueryProceed` must have been called and have returned false before `rayQueryGetCommittedIntersection` + /// or `getCommittedHitVertexPositions` is called. + /// + /// The aim is that these cases will not cause UB if this is set to true, but currently UB can still happen on DX12 and Metal. + pub ray_query_initialization_tracking: bool, } impl ShaderRuntimeChecks { @@ -7915,6 +7927,7 @@ impl ShaderRuntimeChecks { Self { bounds_checks: all_checks, force_loop_bounding: all_checks, + ray_query_initialization_tracking: all_checks, } } }