diff --git a/include/NZSL/Parser.hpp b/include/NZSL/Parser.hpp index 88e13492..eb62cd07 100644 --- a/include/NZSL/Parser.hpp +++ b/include/NZSL/Parser.hpp @@ -58,7 +58,7 @@ namespace nzsl Ast::ExpressionPtr BuildUnary(Ast::UnaryType unaryType, Ast::ExpressionPtr expr); // Statements - Ast::StatementPtr ParseAliasDeclaration(); + Ast::StatementPtr ParseAliasDeclaration(std::vector attributes = {}); Ast::StatementPtr ParseBranchStatement(); Ast::StatementPtr ParseBreakStatement(); Ast::StatementPtr ParseConstStatement(std::vector attributes = {}); diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index c0892a67..00349ab2 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -3541,8 +3541,14 @@ namespace nzsl::Ast std::size_t SanitizeVisitor::RegisterAlias(std::string name, std::optional aliasData, std::optional index, const SourceLocation& sourceLocation) { - if (!IsIdentifierAvailable(name)) - throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; + bool unresolved = false; + if (const IdentifierData* identifierData = FindIdentifier(name)) + { + if (!m_context->inConditionalStatement || !identifierData->isConditional) + throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; + else + unresolved = true; + } std::size_t aliasIndex; if (aliasData) @@ -3555,14 +3561,19 @@ namespace nzsl::Ast else aliasIndex = m_context->aliases.RegisterNewIndex(true); - m_context->currentEnv->identifiersInScope.push_back({ - std::move(name), - { - aliasIndex, - IdentifierCategory::Alias, - m_context->inConditionalStatement - } - }); + if (!unresolved) + { + m_context->currentEnv->identifiersInScope.push_back({ + std::move(name), + { + aliasIndex, + IdentifierCategory::Alias, + m_context->inConditionalStatement + } + }); + } + else + RegisterUnresolved(std::move(name)); return aliasIndex; } diff --git a/src/NZSL/Parser.cpp b/src/NZSL/Parser.cpp index 5244c94c..4d7cf447 100644 --- a/src/NZSL/Parser.cpp +++ b/src/NZSL/Parser.cpp @@ -518,10 +518,25 @@ namespace nzsl return unaryExpr; } - Ast::StatementPtr Parser::ParseAliasDeclaration() + Ast::StatementPtr Parser::ParseAliasDeclaration(std::vector attributes) { const Token& aliasToken = Expect(Advance(), TokenType::Alias); + Ast::ExpressionValue condition; + + for (auto&& attribute : attributes) + { + switch (attribute.type) + { + case Ast::AttributeType::Cond: + HandleUniqueAttribute(condition, std::move(attribute)); + break; + + default: + throw ParserUnexpectedAttributeError{ attribute.sourceLocation, attribute.type, "alias declaration" }; + } + } + std::string name = ParseIdentifierAsName(nullptr); Expect(Advance(), TokenType::Assign); @@ -533,7 +548,10 @@ namespace nzsl auto aliasStatement = ShaderBuilder::DeclareAlias(std::move(name), std::move(expr)); aliasStatement->sourceLocation = SourceLocation::BuildFromTo(aliasToken.location, endToken.location); - return aliasStatement; + if (condition.HasValue()) + return ShaderBuilder::ConditionalStatement(std::move(condition).GetExpression(), std::move(aliasStatement)); + else + return aliasStatement; } Ast::StatementPtr Parser::ParseBranchStatement() @@ -1069,10 +1087,7 @@ namespace nzsl switch (nextToken.type) { case TokenType::Alias: - if (!attributes.empty()) - throw ParserUnexpectedAttributeError{ attributes.front().sourceLocation, attributes.front().type, "alias declaration" }; - - return ParseAliasDeclaration(); + return ParseAliasDeclaration(std::move(attributes)); case TokenType::Const: return ParseConstStatement(std::move(attributes)); diff --git a/tests/src/Tests/AliasTests.cpp b/tests/src/Tests/AliasTests.cpp index b669c44d..7380791b 100644 --- a/tests/src/Tests/AliasTests.cpp +++ b/tests/src/Tests/AliasTests.cpp @@ -93,4 +93,201 @@ OpStore OpReturn OpFunctionEnd)"); } + + SECTION("Conditional aliases") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +struct ForwardOutput +{ + [location(0)] color: vec4[f32] +} + +struct DeferredOutput +{ + [location(0)] color: vec4[f32], + [location(1)] normal: vec3[f32] +} + +option ForwardPass: bool; + +[cond(ForwardPass)] +alias FragOut = ForwardOutput; + +[cond(!ForwardPass)] +alias FragOut = DeferredOutput; + +[entry(frag)] +fn main() -> FragOut +{ + let output: FragOut; + output.color = vec4[f32](0.0, 0.0, 1.0, 1.0); + const if (!ForwardPass) + output.normal = vec3[f32](0.0, 1.0, 0.0); + + return output; +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + + WHEN("We perform a partial sanitization") + { + nzsl::Ast::SanitizeVisitor::Options options; + options.partialSanitization = true; + + shaderModule = SanitizeModule(*shaderModule, options); + } + + WHEN("We enable ForwardPass") + { + nzsl::Ast::SanitizeVisitor::Options options; + options.optionValues[nzsl::Ast::HashOption("ForwardPass")] = true; + options.removeOptionDeclaration = true; + + shaderModule = SanitizeModule(*shaderModule, options); + + ExpectGLSL(*shaderModule, R"( +struct ForwardOutput +{ + vec4 color; +}; + +struct DeferredOutput +{ + vec4 color; + vec3 normal; +}; + +/*************** Outputs ***************/ +layout(location = 0) out vec4 _nzslOutcolor; + +void main() +{ + ForwardOutput output_; + output_.color = vec4(0.0, 0.0, 1.0, 1.0); + + _nzslOutcolor = output_.color; + return; +} +)"); + + ExpectNZSL(*shaderModule, R"( +struct ForwardOutput +{ + [location(0)] color: vec4[f32] +} + +struct DeferredOutput +{ + [location(0)] color: vec4[f32], + [location(1)] normal: vec3[f32] +} + +alias FragOut = ForwardOutput; + +[entry(frag)] +fn main() -> FragOut +{ + let output: FragOut; + output.color = vec4[f32](0.0, 0.0, 1.0, 1.0); + return output; +} +)"); + + ExpectSPIRV(*shaderModule, R"( +OpFunction +OpLabel +OpVariable +OpCompositeConstruct +OpAccessChain +OpStore +OpLoad +OpCompositeExtract +OpStore +OpReturn +OpFunctionEnd)"); + } + + WHEN("We disable ForwardPass") + { + nzsl::Ast::SanitizeVisitor::Options options; + options.optionValues[nzsl::Ast::HashOption("ForwardPass")] = false; + options.removeOptionDeclaration = true; + + shaderModule = SanitizeModule(*shaderModule, options); + + ExpectGLSL(*shaderModule, R"( +struct ForwardOutput +{ + vec4 color; +}; + +struct DeferredOutput +{ + vec4 color; + vec3 normal; +}; + +/*************** Outputs ***************/ +layout(location = 0) out vec4 _nzslOutcolor; +layout(location = 1) out vec3 _nzslOutnormal; + +void main() +{ + DeferredOutput output_; + output_.color = vec4(0.0, 0.0, 1.0, 1.0); + output_.normal = vec3(0.0, 1.0, 0.0); + + _nzslOutcolor = output_.color; + _nzslOutnormal = output_.normal; + return; +} +)"); + + ExpectNZSL(*shaderModule, R"( +struct ForwardOutput +{ + [location(0)] color: vec4[f32] +} + +struct DeferredOutput +{ + [location(0)] color: vec4[f32], + [location(1)] normal: vec3[f32] +} + +alias FragOut = DeferredOutput; + +[entry(frag)] +fn main() -> FragOut +{ + let output: FragOut; + output.color = vec4[f32](0.0, 0.0, 1.0, 1.0); + output.normal = vec3[f32](0.0, 1.0, 0.0); + return output; +} +)"); + + ExpectSPIRV(*shaderModule, R"( +OpFunction +OpLabel +OpVariable +OpCompositeConstruct +OpAccessChain +OpStore +OpCompositeConstruct +OpAccessChain +OpStore +OpLoad +OpCompositeExtract +OpStore +OpCompositeExtract +OpStore +OpReturn +OpFunctionEnd)"); + } + } } diff --git a/tests/src/Tests/ErrorsTests.cpp b/tests/src/Tests/ErrorsTests.cpp index 39727f57..58425519 100644 --- a/tests/src/Tests/ErrorsTests.cpp +++ b/tests/src/Tests/ErrorsTests.cpp @@ -68,9 +68,9 @@ module; [nzsl_version("1.0")] module; -[cond(false)] +[layout(std140)] alias vec3f32 = vec3[f32]; -)"), "(5,2 -> 12): PUnexpectedAttribute error: unexpected attribute cond on alias declaration"); +)"), "(5,2 -> 15): PUnexpectedAttribute error: unexpected attribute layout on alias declaration"); // import statements don't support cond attribute CHECK_THROWS_WITH(nzsl::Parse(R"(