From 3c93f5b389b2e24ecc56ed1ca67f7d15800880a5 Mon Sep 17 00:00:00 2001 From: NiiRoZz Date: Sat, 23 Nov 2024 15:17:52 +0100 Subject: [PATCH] Add support for named external block --- include/NZSL/Ast/Compare.inl | 6 + include/NZSL/Ast/Nodes.hpp | 2 + include/NZSL/Ast/SanitizeVisitor.hpp | 2 + src/NZSL/Ast/Cloner.cpp | 2 + src/NZSL/Ast/ReflectVisitor.cpp | 2 +- src/NZSL/Ast/SanitizeVisitor.cpp | 49 +++++++- src/NZSL/GlslWriter.cpp | 2 +- src/NZSL/LangWriter.cpp | 13 +- src/NZSL/Parser.cpp | 8 +- src/NZSL/SpirvWriter.cpp | 2 +- tests/src/Tests/ErrorsTests.cpp | 41 +++++++ tests/src/Tests/ExternalTests.cpp | 170 +++++++++++++++++++++++++++ 12 files changed, 287 insertions(+), 12 deletions(-) diff --git a/include/NZSL/Ast/Compare.inl b/include/NZSL/Ast/Compare.inl index a8edbb6a..da301433 100644 --- a/include/NZSL/Ast/Compare.inl +++ b/include/NZSL/Ast/Compare.inl @@ -524,6 +524,9 @@ namespace nzsl::Ast if (!Compare(lhs.variableId, rhs.variableId, params)) return false; + if (!Compare(lhs.prefix, rhs.prefix, params)) + return false; + return true; } @@ -609,6 +612,9 @@ namespace nzsl::Ast if (!Compare(lhs.tag, rhs.tag, params)) return false; + if (!Compare(lhs.name, rhs.name, params)) + return false; + if (!Compare(lhs.externalVars, rhs.externalVars, params)) return false; diff --git a/include/NZSL/Ast/Nodes.hpp b/include/NZSL/Ast/Nodes.hpp index 6dc7ab99..c328db2b 100644 --- a/include/NZSL/Ast/Nodes.hpp +++ b/include/NZSL/Ast/Nodes.hpp @@ -242,6 +242,7 @@ namespace nzsl::Ast void Visit(ExpressionVisitor& visitor) override; std::size_t variableId; + std::string prefix; }; struct NZSL_API UnaryExpression : Expression @@ -347,6 +348,7 @@ namespace nzsl::Ast SourceLocation sourceLocation; }; + std::string name; std::string tag; std::vector externalVars; ExpressionValue bindingSet; diff --git a/include/NZSL/Ast/SanitizeVisitor.hpp b/include/NZSL/Ast/SanitizeVisitor.hpp index 1205c9e7..811079d0 100644 --- a/include/NZSL/Ast/SanitizeVisitor.hpp +++ b/include/NZSL/Ast/SanitizeVisitor.hpp @@ -148,6 +148,7 @@ namespace nzsl::Ast std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex); void RegisterReservedName(std::string name); + void RegisterExternalName(std::string name, const SourceLocation& sourceLocation); std::size_t RegisterStruct(std::string name, std::optional description, std::optional index, const SourceLocation& sourceLocation); std::size_t RegisterType(std::string name, std::optional expressionType, std::optional index, const SourceLocation& sourceLocation); std::size_t RegisterType(std::string name, std::optional partialType, std::optional index, const SourceLocation& sourceLocation); @@ -207,6 +208,7 @@ namespace nzsl::Ast { Alias, Constant, + External, Function, Intrinsic, Module, diff --git a/src/NZSL/Ast/Cloner.cpp b/src/NZSL/Ast/Cloner.cpp index 9343884f..6f6deb75 100644 --- a/src/NZSL/Ast/Cloner.cpp +++ b/src/NZSL/Ast/Cloner.cpp @@ -128,6 +128,7 @@ namespace nzsl::Ast clone->autoBinding = Clone(node.autoBinding); clone->bindingSet = Clone(node.bindingSet); clone->tag = node.tag; + clone->name = node.name; clone->externalVars.reserve(node.externalVars.size()); for (const auto& var : node.externalVars) @@ -593,6 +594,7 @@ namespace nzsl::Ast { auto clone = std::make_unique(); clone->variableId = node.variableId; + clone->prefix = node.prefix; clone->cachedExpressionType = node.cachedExpressionType; clone->sourceLocation = node.sourceLocation; diff --git a/src/NZSL/Ast/ReflectVisitor.cpp b/src/NZSL/Ast/ReflectVisitor.cpp index 72a42f1c..2254ab6b 100644 --- a/src/NZSL/Ast/ReflectVisitor.cpp +++ b/src/NZSL/Ast/ReflectVisitor.cpp @@ -57,7 +57,7 @@ namespace nzsl::Ast for (const auto& extVar : node.externalVars) { if (extVar.varIndex) - m_callbacks->onVariableIndex(extVar.name, *extVar.varIndex, extVar.sourceLocation); + m_callbacks->onVariableIndex(node.name + extVar.name, *extVar.varIndex, extVar.sourceLocation); } } diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index 4120d45e..6da897ba 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -312,11 +312,12 @@ namespace nzsl::Ast MandatoryExpr(node.expr, node.sourceLocation); - // Handle module access (TODO: Add namespace expression?) + // Handle module access and named external access (TODO: Add namespace expression?) if (node.expr->GetType() == NodeType::IdentifierExpression && node.identifiers.size() == 1) { auto& identifierExpr = static_cast(*node.expr); const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier); + if (identifierData && identifierData->category == IdentifierCategory::Module) { std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index, node.sourceLocation); @@ -331,6 +332,17 @@ namespace nzsl::Ast return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation); } } + + if (identifierData && identifierData->category == IdentifierCategory::External) + { + identifierData = FindIdentifier(identifierExpr.identifier + node.identifiers.front().identifier); + if (identifierData) + { + auto variableValuePtr = HandleIdentifier(identifierData, node.identifiers.front().sourceLocation); + static_cast(variableValuePtr.get())->prefix = identifierExpr.identifier; + return variableValuePtr; + } + } } ExpressionPtr indexedExpr = CloneExpression(node.expr); @@ -1462,27 +1474,36 @@ NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized") } }; + if (!clone->name.empty()) + { + RegisterExternalName(clone->name, clone->sourceLocation); + } + bool hasUnresolved = false; Nz::StackVector autoBindingEntries = NazaraStackVector(std::size_t, clone->externalVars.size()); for (std::size_t i = 0; i < clone->externalVars.size(); ++i) { auto& extVar = clone->externalVars[i]; + + SanitizeIdentifier(extVar.name, IdentifierScope::ExternalVariable); + + std::string fullName = clone->name + extVar.name; Context::UsedExternalData usedBindingData; usedBindingData.conditionalStatementIndex = m_context->currentConditionalIndex; - if (auto it = m_context->declaredExternalVar.find(extVar.name); it != m_context->declaredExternalVar.end()) + if (auto it = m_context->declaredExternalVar.find(fullName); it != m_context->declaredExternalVar.end()) { if (it->second.conditionalStatementIndex == m_context->currentConditionalIndex || usedBindingData.conditionalStatementIndex == m_context->currentConditionalIndex) throw CompilerExtAlreadyDeclaredError{ extVar.sourceLocation, extVar.name }; } - m_context->declaredExternalVar.emplace(extVar.name, usedBindingData); + m_context->declaredExternalVar.emplace(fullName, usedBindingData); std::optional resolvedType = ResolveTypeExpr(extVar.type, false, node.sourceLocation); if (!resolvedType.has_value()) { - RegisterUnresolved(extVar.name); + RegisterUnresolved(fullName); hasUnresolved = true; continue; } @@ -1553,8 +1574,7 @@ NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized") } extVar.type = std::move(resolvedType).value(); - extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex, extVar.sourceLocation); - SanitizeIdentifier(extVar.name, IdentifierScope::ExternalVariable); + extVar.varIndex = RegisterVariable(fullName, std::move(varType), extVar.varIndex, extVar.sourceLocation); } // Resolve auto-binding entries when explicit binding are known @@ -2866,6 +2886,9 @@ NAZARA_WARNING_POP() return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression } + case IdentifierCategory::External: + throw AstUnexpectedIdentifierError{ sourceLocation, "external" }; + case IdentifierCategory::Function: { // Replace IdentifierExpression by FunctionExpression @@ -3737,6 +3760,20 @@ NAZARA_WARNING_POP() }); } + void SanitizeVisitor::RegisterExternalName(std::string name, const SourceLocation& sourceLocation) + { + if (!IsIdentifierAvailable(name)) + throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name }; + + m_context->currentEnv->identifiersInScope.push_back({ + std::move(name), + { + std::numeric_limits::max(), + IdentifierCategory::External + } + }); + } + std::size_t SanitizeVisitor::RegisterStruct(std::string name, std::optional description, std::optional index, const SourceLocation& sourceLocation) { bool unresolved = false; diff --git a/src/NZSL/GlslWriter.cpp b/src/NZSL/GlslWriter.cpp index b39b7dd7..149464fd 100644 --- a/src/NZSL/GlslWriter.cpp +++ b/src/NZSL/GlslWriter.cpp @@ -2267,7 +2267,7 @@ namespace nzsl AppendComment("struct tag: " + structInfo.desc->tag); } - std::string varName = externalVar.name + m_currentState->moduleSuffix; + std::string varName = node.name + externalVar.name + m_currentState->moduleSuffix; // Layout handling bool hasLayout = false; diff --git a/src/NZSL/LangWriter.cpp b/src/NZSL/LangWriter.cpp index bb807bb9..1203ac74 100644 --- a/src/NZSL/LangWriter.cpp +++ b/src/NZSL/LangWriter.cpp @@ -1260,6 +1260,11 @@ namespace nzsl void LangWriter::Visit(Ast::VariableValueExpression& node) { + if (!node.prefix.empty()) + { + Append(node.prefix, '.'); + } + AppendIdentifier(m_currentState->variables, node.variableId); } @@ -1366,7 +1371,13 @@ namespace nzsl void LangWriter::Visit(Ast::DeclareExternalStatement& node) { AppendAttributes(true, SetAttribute{ node.bindingSet }, AutoBindingAttribute{ node.autoBinding }, TagAttribute{ node.tag }); - AppendLine("external"); + Append("external"); + + if (!node.name.empty()) + Append(" ", node.name); + + AppendLine(); + EnterScope(); bool first = true; diff --git a/src/NZSL/Parser.cpp b/src/NZSL/Parser.cpp index 86df213b..4dbf01f2 100644 --- a/src/NZSL/Parser.cpp +++ b/src/NZSL/Parser.cpp @@ -708,11 +708,15 @@ namespace nzsl NAZARA_USE_ANONYMOUS_NAMESPACE const Token& externalToken = Expect(Advance(), TokenType::External); - Expect(Advance(), TokenType::OpenCurlyBracket); - + std::unique_ptr externalStatement = std::make_unique(); externalStatement->sourceLocation = externalToken.location; + if (const Token& peekToken = Peek(); peekToken.type == TokenType::Identifier) + externalStatement->name = ParseIdentifierAsName(nullptr); + + Expect(Advance(), TokenType::OpenCurlyBracket); + Ast::ExpressionValue condition; for (auto&& attribute : attributes) diff --git a/src/NZSL/SpirvWriter.cpp b/src/NZSL/SpirvWriter.cpp index fd9338c8..ad56241c 100644 --- a/src/NZSL/SpirvWriter.cpp +++ b/src/NZSL/SpirvWriter.cpp @@ -192,7 +192,7 @@ namespace nzsl ExternalVar& extVarData = extVars[*extVar.varIndex]; SpirvConstantCache::Variable variable; - variable.debugName = extVar.name; + variable.debugName = node.name + extVar.name; const Ast::ExpressionType& extVarType = extVar.type.GetResultingValue(); diff --git a/tests/src/Tests/ErrorsTests.cpp b/tests/src/Tests/ErrorsTests.cpp index 8b887821..db1ad5bf 100644 --- a/tests/src/Tests/ErrorsTests.cpp +++ b/tests/src/Tests/ErrorsTests.cpp @@ -636,6 +636,47 @@ fn main() } )"), "(28,19 -> 28): CFunctionCallUnmatchingParameterType error: function GetValue parameter #0 type mismatch (expected array[struct Inner, 3], got array[uniform[struct Inner], 3])"); + CHECK_THROWS_WITH(Compile(R"( +[nzsl_version("1.0")] +module; + +external Viewer +{ + [tag("Color map")] + [binding(0)] tex: sampler2D[f32] +} + +external Viewer +{ + [tag("Color map")] + [binding(1)] tex2: sampler2D[f32] +} + +[entry(frag)] +fn main() +{ + let value = Viewer.tex.Sample(vec2[f32](0.0, 0.0)); +} +)"), "(11,1 -> 8): CIdentifierAlreadyUsed error: identifier Viewer is already used"); + + CHECK_THROWS_WITH(Compile(R"( +[nzsl_version("1.0")] +module; + +external Viewer +{ + [tag("Color map")] + [binding(0)] tex: sampler2D[f32] +} + +[entry(frag)] +fn main() +{ + let Viewer = 0.0; + let value = Viewer.tex.Sample(vec2[f32](0.0, 0.0)); +} +)"), "(14,2 -> 18): CIdentifierAlreadyUsed error: identifier Viewer is already used"); + } /************************************************************************/ diff --git a/tests/src/Tests/ExternalTests.cpp b/tests/src/Tests/ExternalTests.cpp index 3b09a680..f3a28882 100644 --- a/tests/src/Tests/ExternalTests.cpp +++ b/tests/src/Tests/ExternalTests.cpp @@ -1278,4 +1278,174 @@ fn main() OpReturn OpFunctionEnd)", {}, {}, true); } + + SECTION("named external") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +external Viewer +{ + [tag("Color map")] + [binding(0)] tex: sampler2D[f32] +} + +[entry(frag)] +fn main() +{ + let value = Viewer.tex.Sample(vec2[f32](0.0, 0.0)); +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + shaderModule = SanitizeModule(*shaderModule); + + ExpectGLSL(*shaderModule, R"( +// fragment shader - this file was generated by NZSL compiler (Nazara Shading Language) + +precision highp int; +#if GL_FRAGMENT_PRECISION_HIGH +precision highp float; +precision highp sampler2D; +#else +precision mediump float; +precision mediump sampler2D; +#endif + +// header end + +// external var tag: Color map +uniform sampler2D Viewertex; + +void main() +{ + vec4 value = texture(Viewertex, vec2(0.0, 0.0)); +} +)"); + + ExpectNZSL(*shaderModule, R"( +external Viewer +{ + [set(0), binding(0), tag("Color map")] tex: sampler2D[f32] +} + +[entry(frag)] +fn main() +{ + let value: vec4[f32] = Viewer.tex.Sample(vec2[f32](0.0, 0.0)); +})"); + + ExpectSPIRV(*shaderModule, R"( + %1 = OpTypeFloat 32 + %2 = OpTypeImage %1 Dim(Dim2D) 0 0 0 1 ImageFormat(Unknown) + %3 = OpTypeSampledImage %2 + %4 = OpTypePointer StorageClass(UniformConstant) %3 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpConstant %1 f32(0) + %9 = OpTypeVector %1 2 +%10 = OpTypeVector %1 4 +%11 = OpTypePointer StorageClass(Function) %10 + %5 = OpVariable %4 StorageClass(UniformConstant) +%12 = OpFunction %6 FunctionControl(0) %7 +%13 = OpLabel +%14 = OpVariable %11 StorageClass(Function) +%15 = OpLoad %3 %5 +%16 = OpCompositeConstruct %9 %8 %8 +%17 = OpImageSampleImplicitLod %10 %15 %16 + OpStore %14 %17 + OpReturn + OpFunctionEnd)", {}, {}, true); + } + + SECTION("named external shadowing") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +external Viewer +{ + [tag("Color map")] + [binding(0)] tex: sampler2D[f32] +} + +[entry(frag)] +fn main() +{ + let Viewertex = 0.0; + let value = Viewertex; +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + shaderModule = SanitizeModule(*shaderModule); + + ExpectGLSL(*shaderModule, R"( +// fragment shader - this file was generated by NZSL compiler (Nazara Shading Language) + +precision highp int; +#if GL_FRAGMENT_PRECISION_HIGH +precision highp float; +precision highp sampler2D; +#else +precision mediump float; +precision mediump sampler2D; +#endif + +// header end + +// external var tag: Color map +uniform sampler2D Viewertex; + +void main() +{ + float Viewertex_2 = 0.0; + float value = Viewertex_2; +} +)"); + + ExpectNZSL(*shaderModule, R"( +external Viewer +{ + [set(0), binding(0), tag("Color map")] tex: sampler2D[f32] +} + +[entry(frag)] +fn main() +{ + let Viewertex: f32 = 0.0; + let value: f32 = Viewertex; +})"); + + ExpectSPIRV(*shaderModule, R"( + OpCapability Capability(Shader) + OpMemoryModel AddressingModel(Logical) MemoryModel(GLSL450) + OpEntryPoint ExecutionModel(Fragment) %10 "main" + OpExecutionMode %10 ExecutionMode(OriginUpperLeft) + OpSource SourceLanguage(NZSL) 100 + OpName %5 "Viewertex" + OpName %10 "main" + OpDecorate %5 Decoration(Binding) 0 + OpDecorate %5 Decoration(DescriptorSet) 0 + %1 = OpTypeFloat 32 + %2 = OpTypeImage %1 Dim(Dim2D) 0 0 0 1 ImageFormat(Unknown) + %3 = OpTypeSampledImage %2 + %4 = OpTypePointer StorageClass(UniformConstant) %3 + %6 = OpTypeVoid + %7 = OpTypeFunction %6 + %8 = OpConstant %1 f32(0) + %9 = OpTypePointer StorageClass(Function) %1 + %5 = OpVariable %4 StorageClass(UniformConstant) +%10 = OpFunction %6 FunctionControl(0) %7 +%11 = OpLabel +%12 = OpVariable %9 StorageClass(Function) +%13 = OpVariable %9 StorageClass(Function) + OpStore %12 %8 +%14 = OpLoad %1 %12 + OpStore %13 %14 + OpReturn + OpFunctionEnd)", {}, {}, true); + } }