From 6f9343181501eb7046544b4ed36feefe057f5ff2 Mon Sep 17 00:00:00 2001 From: SirLynix Date: Wed, 3 Jul 2024 21:10:07 +0200 Subject: [PATCH] Compiler: Remove passes from SanitizeVisitor this is to get a better view of the work remaining --- .github/workflows/coverage.yml | 4 +- .github/workflows/linux-build.yml | 2 +- .github/workflows/macos-build.yml | 2 +- .github/workflows/msys2-build.yml | 2 +- .github/workflows/windows-build.yml | 2 +- include/NZSL/Ast/SanitizeVisitor.hpp | 6 - src/NZSL/Ast/SanitizeVisitor.cpp | 479 ++----------------------- src/NZSL/GlslWriter.cpp | 6 +- src/NZSL/SpirvWriter.cpp | 10 +- tests/src/Tests/SanitizationsTests.cpp | 313 +++++++++++----- 10 files changed, 279 insertions(+), 547 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 0e18535..a3c28e2 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -108,14 +108,14 @@ jobs: - name: Run unit tests and generate coverage output (Linux) if: runner.os == 'Linux' run: | - xmake run UnitTests + xmake run UnitTests sanitizing gcovr -b -x coverage.out -s -f 'include/NZSL/.*' -f 'src/NZSL/.*' -e 'src/NZSL/SpirV/SpirvData.cpp' build/.objs/ - name: Run unit tests and generate coverage output (Windows) if: runner.os == 'Windows' shell: cmd run: | - "C:\Program Files\OpenCppCoverage\OpenCppCoverage.exe" --export_type cobertura:coverage.out --sources "ShaderLang\include\NZSL\*" --sources "ShaderLang\src\NZSL\*" --excluded_sources "ShaderLang\src\NZSL\SpirV\SpirvData.cpp" --modules "ShaderLang\bin\*" --cover_children -- xmake run UnitTests + "C:\Program Files\OpenCppCoverage\OpenCppCoverage.exe" --export_type cobertura:coverage.out --sources "ShaderLang\include\NZSL\*" --sources "ShaderLang\src\NZSL\*" --excluded_sources "ShaderLang\src\NZSL\SpirV\SpirvData.cpp" --modules "ShaderLang\bin\*" --cover_children -- xmake run UnitTests sanitizing - name: Upload Coverage Report to Codecov uses: codecov/codecov-action@v4 diff --git a/.github/workflows/linux-build.yml b/.github/workflows/linux-build.yml index 73ef739..ea9724a 100644 --- a/.github/workflows/linux-build.yml +++ b/.github/workflows/linux-build.yml @@ -85,7 +85,7 @@ jobs: # Run unit tests - name: Run unit tests - run: xmake run UnitTests + run: xmake run UnitTests sanitizing # Install the result files - name: Install NZSL diff --git a/.github/workflows/macos-build.yml b/.github/workflows/macos-build.yml index 82378b9..14ba246 100644 --- a/.github/workflows/macos-build.yml +++ b/.github/workflows/macos-build.yml @@ -75,7 +75,7 @@ jobs: # Run unit tests - name: Run unit tests - run: xmake run UnitTests + run: xmake run UnitTests sanitizing # Install the result files - name: Install NZSL diff --git a/.github/workflows/msys2-build.yml b/.github/workflows/msys2-build.yml index 3db966f..e667e42 100644 --- a/.github/workflows/msys2-build.yml +++ b/.github/workflows/msys2-build.yml @@ -95,7 +95,7 @@ jobs: # Run unit tests - name: Run unit tests - run: xmake run UnitTests + run: xmake run UnitTests sanitizing # Install the result files - name: Install NZSL diff --git a/.github/workflows/windows-build.yml b/.github/workflows/windows-build.yml index eeb1232..2d404ef 100644 --- a/.github/workflows/windows-build.yml +++ b/.github/workflows/windows-build.yml @@ -81,7 +81,7 @@ jobs: # Run unit tests - name: Run unit tests - run: xmake run UnitTests + run: xmake run UnitTests sanitizing # Install the result files - name: Install NZSL diff --git a/include/NZSL/Ast/SanitizeVisitor.hpp b/include/NZSL/Ast/SanitizeVisitor.hpp index 2f84381..ea2dc6a 100644 --- a/include/NZSL/Ast/SanitizeVisitor.hpp +++ b/include/NZSL/Ast/SanitizeVisitor.hpp @@ -47,16 +47,10 @@ namespace nzsl::Ast bool forceAutoBindingResolve = false; bool makeVariableNameUnique = false; bool partialSanitization = false; - bool reduceLoopsToWhile = false; bool removeAliases = false; - bool removeCompoundAssignments = false; bool removeConstArraySize = false; - bool removeMatrixBinaryAddSub = false; - bool removeMatrixCast = false; bool removeOptionDeclaration = false; - bool removeScalarSwizzling = false; bool removeSingleConstDeclaration = false; - bool splitMultipleBranches = false; bool splitWrappedArrayAssignation = false; bool splitWrappedStructAssignation = false; bool useIdentifierAccessesForStructs = true; diff --git a/src/NZSL/Ast/SanitizeVisitor.cpp b/src/NZSL/Ast/SanitizeVisitor.cpp index c0892a6..2fbc8f5 100644 --- a/src/NZSL/Ast/SanitizeVisitor.cpp +++ b/src/NZSL/Ast/SanitizeVisitor.cpp @@ -531,50 +531,16 @@ namespace nzsl::Ast if (swizzleComponentCount > 4) throw CompilerInvalidSwizzleError{ identifierEntry.sourceLocation }; - if (m_context->options.removeScalarSwizzling && IsPrimitiveType(resolvedType)) - { - for (std::size_t j = 0; j < swizzleComponentCount; ++j) - { - if (ToSwizzleIndex(identifierEntry.identifier[j], identifierEntry.sourceLocation) != 0) - throw CompilerInvalidScalarSwizzleError{ identifierEntry.sourceLocation }; - } - - if (swizzleComponentCount == 1) - continue; //< ignore this swizzle (a.x == a) - - // Use a Cast expression to replace swizzle - indexedExpr = CacheResult(std::move(indexedExpr)); //< Since we are going to use a value multiple times, cache it if required - - PrimitiveType baseType; - if (IsVectorType(resolvedType)) - baseType = std::get(resolvedType).type; - else - baseType = std::get(resolvedType); + auto swizzle = std::make_unique(); + swizzle->expression = std::move(indexedExpr); - auto cast = std::make_unique(); - cast->targetType = ExpressionType{ VectorType{ swizzleComponentCount, baseType } }; - - cast->expressions.reserve(swizzleComponentCount); - for (std::size_t j = 0; j < swizzleComponentCount; ++j) - cast->expressions.push_back(CloneExpression(indexedExpr)); - - Validate(*cast); - - indexedExpr = std::move(cast); - } - else - { - auto swizzle = std::make_unique(); - swizzle->expression = std::move(indexedExpr); + swizzle->componentCount = swizzleComponentCount; + for (std::size_t j = 0; j < swizzleComponentCount; ++j) + swizzle->components[j] = ToSwizzleIndex(identifierEntry.identifier[j], identifierEntry.sourceLocation); - swizzle->componentCount = swizzleComponentCount; - for (std::size_t j = 0; j < swizzleComponentCount; ++j) - swizzle->components[j] = ToSwizzleIndex(identifierEntry.identifier[j], identifierEntry.sourceLocation); + Validate(*swizzle); - Validate(*swizzle); - - indexedExpr = std::move(swizzle); - } + indexedExpr = std::move(swizzle); } else throw CompilerUnexpectedAccessedTypeError{ node.sourceLocation }; @@ -713,48 +679,6 @@ namespace nzsl::Ast if (Validate(*clone) == ValidationResult::Unresolved) return clone; - if (m_context->options.removeMatrixBinaryAddSub && (clone->op == BinaryType::Add || clone->op == BinaryType::Subtract)) - { - const ExpressionType& leftExprType = GetExpressionTypeSecure(*clone->left); - const ExpressionType& rightExprType = GetExpressionTypeSecure(*clone->right); - if (IsMatrixType(leftExprType) && IsMatrixType(rightExprType)) - { - const MatrixType& matrixType = std::get(leftExprType); - assert(leftExprType == rightExprType); - - // Since we're going to access both matrices multiples times, make sure we cache them into variables if required - auto leftMatrix = CacheResult(std::move(clone->left)); - auto rightMatrix = CacheResult(std::move(clone->right)); - - std::vector columnExpressions(matrixType.columnCount); - - for (std::size_t i = 0; i < matrixType.columnCount; ++i) - { - // mat[i] - auto leftColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(leftMatrix), ShaderBuilder::ConstantValue(std::uint32_t(i), clone->sourceLocation)); - auto rightColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(rightMatrix), ShaderBuilder::ConstantValue(std::uint32_t(i), clone->sourceLocation)); - - Validate(*leftColumnExpr); - Validate(*rightColumnExpr); - - // lhs[i] +- rhs[i] - auto binOp = ShaderBuilder::Binary(clone->op, std::move(leftColumnExpr), std::move(rightColumnExpr)); - binOp->sourceLocation = clone->sourceLocation; - - Validate(*binOp); - - columnExpressions[i] = std::move(binOp); - } - - // Build resulting matrix - auto result = ShaderBuilder::Cast(leftExprType, std::move(columnExpressions)); - result->sourceLocation = clone->sourceLocation; - - // Re-clone resulting cast operation, so it can be transformed again if required - return Clone(*result); - } - } - return clone; } @@ -915,133 +839,6 @@ namespace nzsl::Ast if (Validate(*clone) == ValidationResult::Unresolved) return clone; //< unresolved - const ExpressionType& targetType = clone->targetType.GetResultingValue(); - - if (m_context->options.removeMatrixCast && IsMatrixType(targetType)) - { - const MatrixType& targetMatrixType = std::get(targetType); - - const ExpressionType& frontExprType = ResolveAlias(GetExpressionTypeSecure(*clone->expressions.front())); - bool isMatrixCast = IsMatrixType(frontExprType); - if (isMatrixCast && std::get(frontExprType) == targetMatrixType) - { - // Nothing to do - return std::move(clone->expressions.front()); - } - - auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", targetType); //< Validation will prevent name-clash if required - variableDeclaration->sourceLocation = node.sourceLocation; - Validate(*variableDeclaration); - - std::size_t variableIndex = *variableDeclaration->varIndex; - - m_context->currentStatementList->emplace_back(std::move(variableDeclaration)); - - ExpressionPtr cachedDiagonalValue; - - for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i) - { - // temp[i] - auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, targetType, node.sourceLocation), ShaderBuilder::ConstantValue(std::uint32_t(i), node.sourceLocation)); - Validate(*columnExpr); - - // vector expression - ExpressionPtr vectorExpr; - std::size_t vectorComponentCount; - if (isMatrixCast) - { - // fromMatrix[i] - auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::ConstantValue(std::uint32_t(i), node.sourceLocation)); - Validate(*matrixColumnExpr); - - vectorExpr = std::move(matrixColumnExpr); - vectorComponentCount = std::get(frontExprType).rowCount; - } - else if (IsVectorType(frontExprType)) - { - // parameter #i - vectorExpr = std::move(clone->expressions[i]); - vectorComponentCount = std::get(ResolveAlias(GetExpressionTypeSecure(*vectorExpr))).componentCount; - } - else - { - assert(IsPrimitiveType(frontExprType)); - - // Use a Cast expression to replace swizzle - std::vector expressions(targetMatrixType.rowCount); - SourceLocation location; - for (std::size_t j = 0; j < targetMatrixType.rowCount; ++j) - { - if (clone->expressions.size() == 1) //< diagonal value - { - if (!cachedDiagonalValue) - cachedDiagonalValue = CacheResult(std::move(clone->expressions.front())); - - if (i == j) - expressions[j] = CloneExpression(cachedDiagonalValue); - else - expressions[j] = ShaderBuilder::ConstantValue(ExpressionType{ targetMatrixType.type }, 0, node.sourceLocation); - } - else - expressions[j] = std::move(clone->expressions[i * targetMatrixType.rowCount + j]); - - if (j == 0) - location = expressions[j]->sourceLocation; - else - location.ExtendToRight(expressions[j]->sourceLocation); - } - - auto buildVec = ShaderBuilder::Cast(ExpressionType{ VectorType{ targetMatrixType.rowCount, targetMatrixType.type } }, std::move(expressions)); - buildVec->sourceLocation = location; - Validate(*buildVec); - - vectorExpr = std::move(buildVec); - vectorComponentCount = targetMatrixType.rowCount; - } - - // cast expression (turn fromMatrix[i] to vec3[f32](fromMatrix[i])) - ExpressionPtr castExpr; - if (vectorComponentCount != targetMatrixType.rowCount) - { - CastExpressionPtr vecCast; - if (vectorComponentCount < targetMatrixType.rowCount) - { - std::vector expressions; - expressions.push_back(std::move(vectorExpr)); - for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j) - expressions.push_back(ShaderBuilder::ConstantValue(ExpressionType{ targetMatrixType.type }, (i == j + vectorComponentCount) ? 1 : 0, node.sourceLocation)); //< set 1 to diagonal - - vecCast = ShaderBuilder::Cast(ExpressionType{ VectorType{ targetMatrixType.rowCount, targetMatrixType.type } }, std::move(expressions)); - vecCast->sourceLocation = node.sourceLocation; - Validate(*vecCast); - - castExpr = std::move(vecCast); - } - else - { - std::array swizzleComponents; - std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0); - - auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount); - swizzleExpr->sourceLocation = node.sourceLocation; - Validate(*swizzleExpr); - - castExpr = std::move(swizzleExpr); - } - } - else - castExpr = std::move(vectorExpr); - - // temp[i] = castExpr - auto assignExpr = ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr)); - assignExpr->sourceLocation = node.sourceLocation; - - m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(std::move(assignExpr))); - } - - return ShaderBuilder::Variable(variableIndex, targetType, node.sourceLocation); - } - return clone; } @@ -1189,50 +986,15 @@ namespace nzsl::Ast } const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + + auto clone = std::make_unique(); + clone->componentCount = node.componentCount; + clone->components = node.components; + clone->expression = std::move(expression); + clone->sourceLocation = node.sourceLocation; + Validate(*clone); - if (m_context->options.removeScalarSwizzling && IsPrimitiveType(resolvedExprType)) - { - for (std::size_t i = 0; i < node.componentCount; ++i) - { - if (node.components[i] != 0) - throw CompilerInvalidScalarSwizzleError{ node.sourceLocation }; - } - - if (node.componentCount == 1) - return expression; //< ignore this swizzle (a.x == a) - - // Use a Cast expression to replace swizzle - expression = CacheResult(std::move(expression)); //< Since we are going to use a value multiple times, cache it if required - - PrimitiveType baseType; - if (IsVectorType(resolvedExprType)) - baseType = std::get(resolvedExprType).type; - else - baseType = std::get(resolvedExprType); - - auto cast = std::make_unique(); - cast->sourceLocation = node.sourceLocation; - cast->targetType = ExpressionType{ VectorType{ node.componentCount, baseType } }; - - cast->expressions.reserve(node.componentCount); - for (std::size_t j = 0; j < node.componentCount; ++j) - cast->expressions.push_back(CloneExpression(expression)); - - Validate(*cast); - - return cast; - } - else - { - auto clone = std::make_unique(); - clone->componentCount = node.componentCount; - clone->components = node.components; - clone->expression = std::move(expression); - clone->sourceLocation = node.sourceLocation; - Validate(*clone); - - return clone; - } + return clone; } ExpressionPtr SanitizeVisitor::Clone(UnaryExpression& node) @@ -1308,21 +1070,8 @@ namespace nzsl::Ast return ValidationResult::Validated; }; - if (m_context->options.splitMultipleBranches && condIndex > 0) - { - auto currentBranch = std::make_unique(); - - if (BuildCondStatement(currentBranch->condStatements.emplace_back()) == ValidationResult::Unresolved) - return Cloner::Clone(node); - - root->elseStatement = std::move(currentBranch); - root = static_cast(root->elseStatement.get()); - } - else - { - if (BuildCondStatement(clone->condStatements.emplace_back()) == ValidationResult::Unresolved) - return Cloner::Clone(node); - } + if (BuildCondStatement(clone->condStatements.emplace_back()) == ValidationResult::Unresolved) + return Cloner::Clone(node); } if (node.elseStatement) @@ -2147,93 +1896,7 @@ namespace nzsl::Ast } } - if (m_context->options.reduceLoopsToWhile) - { - PushScope(); - Nz::CallOnExit unscoper([&] { PopScope(); }); - - auto multi = std::make_unique(); - multi->sourceLocation = node.sourceLocation; - - // Counter variable - auto counterVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(fromExpr)); - counterVariable->sourceLocation = node.sourceLocation; - counterVariable->varIndex = node.varIndex; - Validate(*counterVariable); - - std::size_t counterVarIndex = counterVariable->varIndex.value(); - multi->statements.emplace_back(std::move(counterVariable)); - - // Target variable - auto targetVariable = ShaderBuilder::DeclareVariable("to", std::move(toExpr)); - targetVariable->sourceLocation = node.sourceLocation; - Validate(*targetVariable); - - std::size_t targetVarIndex = targetVariable->varIndex.value(); - multi->statements.emplace_back(std::move(targetVariable)); - - // Step variable - std::optional stepVarIndex; - - if (stepExpr) - { - auto stepVariable = ShaderBuilder::DeclareVariable("step", std::move(stepExpr)); - stepVariable->sourceLocation = node.sourceLocation; - Validate(*stepVariable); - - stepVarIndex = stepVariable->varIndex; - multi->statements.emplace_back(std::move(stepVariable)); - } - - // While - auto whileStatement = std::make_unique(); - whileStatement->sourceLocation = node.sourceLocation; - whileStatement->unroll = std::move(unrollValue); - - // While condition - auto conditionCounterVariable = ShaderBuilder::Variable(counterVarIndex, counterType, node.sourceLocation); - auto conditionTargetVariable = ShaderBuilder::Variable(targetVarIndex, counterType, node.sourceLocation); - - auto condition = ShaderBuilder::Binary(BinaryType::CompLt, std::move(conditionCounterVariable), std::move(conditionTargetVariable)); - condition->sourceLocation = node.sourceLocation; - Validate(*condition); - - whileStatement->condition = std::move(condition); - - // While body - auto body = std::make_unique(); - body->statements.reserve(2); - { - bool wasInLoop = m_context->inLoop; - m_context->inLoop = true; - Nz::CallOnExit restoreLoop([=] { m_context->inLoop = wasInLoop; }); - - body->statements.emplace_back(Unscope(CloneStatement(node.statement))); - } - - // Counter and increment - ExpressionPtr incrExpr; - if (stepVarIndex) - incrExpr = ShaderBuilder::Variable(*stepVarIndex, counterType, node.sourceLocation); - else - incrExpr = (counterType == PrimitiveType::Int32) ? ShaderBuilder::ConstantValue(1, node.sourceLocation) : ShaderBuilder::ConstantValue(1u, node.sourceLocation); - - incrExpr->sourceLocation = node.sourceLocation; - - auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType, node.sourceLocation), std::move(incrExpr)); - incrCounter->sourceLocation = node.sourceLocation; - Validate(*incrCounter); - - body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); - - whileStatement->body = std::move(body); - - multi->statements.emplace_back(std::move(whileStatement)); - - return ShaderBuilder::Scoped(std::move(multi)); - } - else - return CloneFor(); + return CloneFor(); } StatementPtr SanitizeVisitor::Clone(ForEachStatement& node) @@ -2317,93 +1980,26 @@ namespace nzsl::Ast } } - if (m_context->options.reduceLoopsToWhile) - { - PushScope(); - Nz::CallOnExit unscoper([&] { PopScope(); }); - - auto multi = std::make_unique(); - multi->sourceLocation = node.sourceLocation; - - if (IsArrayType(resolvedExprType)) - { - const ArrayType& arrayType = std::get(resolvedExprType); - - multi->statements.reserve(2); - - // Counter variable - auto counterVariable = ShaderBuilder::DeclareVariable("i", ShaderBuilder::ConstantValue(0u)); - counterVariable->sourceLocation = node.sourceLocation; - - Validate(*counterVariable); - - std::size_t counterVarIndex = counterVariable->varIndex.value(); - - multi->statements.emplace_back(std::move(counterVariable)); - - auto whileStatement = std::make_unique(); - whileStatement->unroll = std::move(unrollValue); - - // While condition - auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, node.sourceLocation), ShaderBuilder::ConstantValue(arrayType.length, node.sourceLocation)); - Validate(*condition); - whileStatement->condition = std::move(condition); - - // While body - auto body = std::make_unique(); - body->statements.reserve(3); - - auto accessIndex = ShaderBuilder::AccessIndex(std::move(expr), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, node.sourceLocation)); - Validate(*accessIndex); - - auto elementVariable = ShaderBuilder::DeclareVariable(node.varName, std::move(accessIndex)); - elementVariable->varIndex = node.varIndex; //< Preserve var index - Validate(*elementVariable); - body->statements.emplace_back(std::move(elementVariable)); - - { - bool wasInLoop = m_context->inLoop; - m_context->inLoop = true; - Nz::CallOnExit restoreLoop([=] { m_context->inLoop = wasInLoop; }); - - body->statements.emplace_back(Unscope(CloneStatement(node.statement))); - } - - auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, node.sourceLocation), ShaderBuilder::ConstantValue(1u, node.sourceLocation)); - Validate(*incrCounter); - - body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); - - whileStatement->body = std::move(body); - - multi->statements.emplace_back(std::move(whileStatement)); - } + auto clone = std::make_unique(); + clone->expression = std::move(expr); + clone->varName = node.varName; + clone->unroll = std::move(unrollValue); + clone->sourceLocation = node.sourceLocation; - return ShaderBuilder::Scoped(std::move(multi)); - } - else + PushScope(); { - auto clone = std::make_unique(); - clone->expression = std::move(expr); - clone->varName = node.varName; - clone->unroll = std::move(unrollValue); - clone->sourceLocation = node.sourceLocation; - - PushScope(); - { - clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex, node.sourceLocation); - SanitizeIdentifier(node.varName, IdentifierScope::Variable); + clone->varIndex = RegisterVariable(node.varName, innerType, node.varIndex, node.sourceLocation); + SanitizeIdentifier(node.varName, IdentifierScope::Variable); - bool wasInLoop = m_context->inLoop; - m_context->inLoop = true; - Nz::CallOnExit restoreLoop([=] { m_context->inLoop = wasInLoop; }); - - clone->statement = CloneStatement(node.statement); - } - PopScope(); + bool wasInLoop = m_context->inLoop; + m_context->inLoop = true; + Nz::CallOnExit restoreLoop([=] { m_context->inLoop = wasInLoop; }); - return clone; + clone->statement = CloneStatement(node.statement); } + PopScope(); + + return clone; } StatementPtr SanitizeVisitor::Clone(ImportStatement& node) @@ -4371,12 +3967,6 @@ namespace nzsl::Ast { ExpressionType expressionType = ValidateBinaryOp(*binaryType, ResolveAlias(*leftExprType), UnwrapExternalType(ResolveAlias(*rightExprType)), node.sourceLocation); TypeMustMatch(UnwrapExternalType(*leftExprType), expressionType, node.sourceLocation); - - if (m_context->options.removeCompoundAssignments) - { - node.op = AssignType::Simple; - node.right = Clone(*ShaderBuilder::Binary(*binaryType, Cloner::Clone(*node.left), std::move(node.right))); - } } node.cachedExpressionType = *leftExprType; @@ -5556,9 +5146,6 @@ namespace nzsl::Ast std::size_t componentCount; if (IsPrimitiveType(resolvedExprType)) { - if (m_context->options.removeScalarSwizzling) - throw AstInternalError{ node.sourceLocation, "scalar swizzling should have been removed before validating" }; - baseType = std::get(resolvedExprType); componentCount = 1; } diff --git a/src/NZSL/GlslWriter.cpp b/src/NZSL/GlslWriter.cpp index 3abe780..0a4eb63 100644 --- a/src/NZSL/GlslWriter.cpp +++ b/src/NZSL/GlslWriter.cpp @@ -504,11 +504,11 @@ namespace nzsl Ast::SanitizeVisitor::Options options; options.makeVariableNameUnique = true; - options.reduceLoopsToWhile = true; + //options.reduceLoopsToWhile = true; options.removeAliases = true; - options.removeCompoundAssignments = false; + //options.removeCompoundAssignments = false; options.removeOptionDeclaration = true; - options.removeScalarSwizzling = true; + //options.removeScalarSwizzling = true; options.removeSingleConstDeclaration = true; options.splitWrappedStructAssignation = true; //< TODO: Only split for base uniforms/storage options.identifierSanitizer = [](std::string& identifier, Ast::IdentifierScope /*scope*/) diff --git a/src/NZSL/SpirvWriter.cpp b/src/NZSL/SpirvWriter.cpp index d84e206..10b5997 100644 --- a/src/NZSL/SpirvWriter.cpp +++ b/src/NZSL/SpirvWriter.cpp @@ -835,16 +835,16 @@ namespace nzsl Ast::SanitizeVisitor::Options SpirvWriter::GetSanitizeOptions() { Ast::SanitizeVisitor::Options options; - options.reduceLoopsToWhile = true; + //options.reduceLoopsToWhile = true; options.removeAliases = true; - options.removeCompoundAssignments = true; + //options.removeCompoundAssignments = true; options.removeConstArraySize = true; - options.removeMatrixBinaryAddSub = true; - options.removeMatrixCast = true; + //options.removeMatrixBinaryAddSub = true; + //options.removeMatrixCast = true; options.removeOptionDeclaration = true; options.removeSingleConstDeclaration = true; options.splitWrappedArrayAssignation = true; - options.splitMultipleBranches = true; + //options.splitMultipleBranches = true; options.splitWrappedStructAssignation = true; options.useIdentifierAccessesForStructs = false; diff --git a/tests/src/Tests/SanitizationsTests.cpp b/tests/src/Tests/SanitizationsTests.cpp index 13d47a3..e9b4189 100644 --- a/tests/src/Tests/SanitizationsTests.cpp +++ b/tests/src/Tests/SanitizationsTests.cpp @@ -4,6 +4,11 @@ #include #include #include +#include +#include +#include +#include +#include #include #include #include @@ -44,10 +49,10 @@ fn main() nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); - nzsl::Ast::SanitizeVisitor::Options options; - options.splitMultipleBranches = true; + nzsl::Ast::BranchSplitterTransformer branchSplitterTransformer; + nzsl::Ast::Transformer::Context context; - REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + REQUIRE_NOTHROW(branchSplitterTransformer.Transform(*shaderModule, context)); ExpectNZSL(*shaderModule, R"( [entry(frag)] @@ -79,6 +84,64 @@ fn main() } +} +)"); + + } + + WHEN("reducing for to while") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +struct inputStruct +{ + value: array[f32, 10] +} + +external +{ + [set(0), binding(0)] data: uniform[inputStruct] +} + +[entry(frag)] +fn main() +{ + let x = 0.0; + for i in 0 -> 10 + { + x += data.value[i]; + } +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule)); + + nzsl::Ast::ForToWhileTransformer forToWhileTransformer; + nzsl::Ast::Transformer::Context context; + context.nextVariableIndex = 10; + + REQUIRE_NOTHROW(forToWhileTransformer.Transform(*shaderModule, context)); + + ExpectNZSL(*shaderModule, R"( +[entry(frag)] +fn main() +{ + let x: f32 = 0.0; + { + let i = 0; + let _nzsl_to: i32 = 10; + while (i < _nzsl_to) + { + x += data.value[i]; + i += 1; + } + + } + } )"); @@ -113,10 +176,13 @@ fn main() nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); - nzsl::Ast::SanitizeVisitor::Options options; - options.reduceLoopsToWhile = true; + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule)); - REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + nzsl::Ast::ForToWhileTransformer forToWhileTransformer; + nzsl::Ast::Transformer::Context context; + context.nextVariableIndex = 10; + + REQUIRE_NOTHROW(forToWhileTransformer.Transform(*shaderModule, context)); ExpectNZSL(*shaderModule, R"( [entry(frag)] @@ -124,12 +190,12 @@ fn main() { let x: f32 = 0.0; { - let i: u32 = u32(0); - while (i < (u32(10))) + let _nzsl_counter: u32 = u32(0); + while (_nzsl_counter < (u32(10))) { - let v: f32 = data.value[i]; + let v = data.value[_nzsl_counter]; x += v; - i += u32(1); + _nzsl_counter += u32(1); } } @@ -139,6 +205,46 @@ fn main() } + WHEN("removing compound assignment") + { + std::string_view nzslSource = R"( +[nzsl_version("1.0")] +module; + +fn main() +{ + let x = 1; + let y = 2; + x += y; + x += 1; +} +)"; + + nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); + + nzsl::Ast::AssignmentTransformer::Options options; + options.removeCompoundAssignment = true; + + nzsl::Ast::AssignmentTransformer assignmentTransformer; + nzsl::Ast::Transformer::Context context; + + REQUIRE_NOTHROW(assignmentTransformer.Transform(*shaderModule, context, options)); + + ExpectNZSL(*shaderModule, R"( +[nzsl_version("1.0")] +module; + +fn main() +{ + let x = 1; + let y = 2; + x = x + y; + x = x + (1); +} +)"); + + } + WHEN("removing matrix casts") { std::string_view nzslSource = R"( @@ -198,18 +304,24 @@ fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32] nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); - nzsl::Ast::SanitizeVisitor::Options options; - options.removeMatrixCast = true; + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule)); - REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + nzsl::Ast::Transformer::Context context; + context.nextVariableIndex = 20; + + nzsl::Ast::MatrixTransformer::Options matrixOptions; + matrixOptions.removeMatrixCast = true; + + nzsl::Ast::MatrixTransformer matrixTransformer; + REQUIRE_NOTHROW(matrixTransformer.Transform(*shaderModule, context, matrixOptions)); ExpectNZSL(*shaderModule, R"( fn buildMat2x3(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32) -> mat2x3[f32] { - let temp: mat2x3[f32]; - temp[u32(0)] = vec3[f32](a, b, c); - temp[u32(1)] = vec3[f32](d, e, f); - return temp; + let _nzsl_matrix: mat2x3[f32]; + _nzsl_matrix[u32(0)] = vec3[f32](a, b, c); + _nzsl_matrix[u32(1)] = vec3[f32](d, e, f); + return _nzsl_matrix; } fn testMat2ToMat2(input: mat2[f32]) -> mat2[f32] @@ -219,29 +331,29 @@ fn testMat2ToMat2(input: mat2[f32]) -> mat2[f32] fn testMat2ToMat3(input: mat2[f32]) -> mat3[f32] { - let temp: mat3[f32]; - temp[u32(0)] = vec3[f32](input[u32(0)], 0.0); - temp[u32(1)] = vec3[f32](input[u32(1)], 0.0); - temp[u32(2)] = vec3[f32](input[u32(2)], 1.0); - return temp; + let _nzsl_matrix: mat3[f32]; + _nzsl_matrix[u32(0)] = vec3[f32](input[u32(0)], 0.0); + _nzsl_matrix[u32(1)] = vec3[f32](input[u32(1)], 0.0); + _nzsl_matrix[u32(2)] = vec3[f32](input[u32(2)], 1.0); + return _nzsl_matrix; } fn testMat2ToMat4(input: mat2[f32]) -> mat4[f32] { - let temp: mat4[f32]; - temp[u32(0)] = vec4[f32](input[u32(0)], 0.0, 0.0); - temp[u32(1)] = vec4[f32](input[u32(1)], 0.0, 0.0); - temp[u32(2)] = vec4[f32](input[u32(2)], 1.0, 0.0); - temp[u32(3)] = vec4[f32](input[u32(3)], 0.0, 1.0); - return temp; + let _nzsl_matrix: mat4[f32]; + _nzsl_matrix[u32(0)] = vec4[f32](input[u32(0)], 0.0, 0.0); + _nzsl_matrix[u32(1)] = vec4[f32](input[u32(1)], 0.0, 0.0); + _nzsl_matrix[u32(2)] = vec4[f32](input[u32(2)], 1.0, 0.0); + _nzsl_matrix[u32(3)] = vec4[f32](input[u32(3)], 0.0, 1.0); + return _nzsl_matrix; } fn testMat3ToMat2(input: mat3[f32]) -> mat2[f32] { - let temp: mat2[f32]; - temp[u32(0)] = input[u32(0)].xy; - temp[u32(1)] = input[u32(1)].xy; - return temp; + let _nzsl_matrix: mat2[f32]; + _nzsl_matrix[u32(0)] = input[u32(0)].xy; + _nzsl_matrix[u32(1)] = input[u32(1)].xy; + return _nzsl_matrix; } fn testMat3ToMat3(input: mat3[f32]) -> mat3[f32] @@ -251,29 +363,29 @@ fn testMat3ToMat3(input: mat3[f32]) -> mat3[f32] fn testMat3ToMat4(input: mat3[f32]) -> mat4[f32] { - let temp: mat4[f32]; - temp[u32(0)] = vec4[f32](input[u32(0)], 0.0); - temp[u32(1)] = vec4[f32](input[u32(1)], 0.0); - temp[u32(2)] = vec4[f32](input[u32(2)], 0.0); - temp[u32(3)] = vec4[f32](input[u32(3)], 1.0); - return temp; + let _nzsl_matrix: mat4[f32]; + _nzsl_matrix[u32(0)] = vec4[f32](input[u32(0)], 0.0); + _nzsl_matrix[u32(1)] = vec4[f32](input[u32(1)], 0.0); + _nzsl_matrix[u32(2)] = vec4[f32](input[u32(2)], 0.0); + _nzsl_matrix[u32(3)] = vec4[f32](input[u32(3)], 1.0); + return _nzsl_matrix; } fn testMat4ToMat2(input: mat4[f32]) -> mat2[f32] { - let temp: mat2[f32]; - temp[u32(0)] = input[u32(0)].xy; - temp[u32(1)] = input[u32(1)].xy; - return temp; + let _nzsl_matrix: mat2[f32]; + _nzsl_matrix[u32(0)] = input[u32(0)].xy; + _nzsl_matrix[u32(1)] = input[u32(1)].xy; + return _nzsl_matrix; } fn testMat4ToMat3(input: mat4[f32]) -> mat3[f32] { - let temp: mat3[f32]; - temp[u32(0)] = input[u32(0)].xyz; - temp[u32(1)] = input[u32(1)].xyz; - temp[u32(2)] = input[u32(2)].xyz; - return temp; + let _nzsl_matrix: mat3[f32]; + _nzsl_matrix[u32(0)] = input[u32(0)].xyz; + _nzsl_matrix[u32(1)] = input[u32(1)].xyz; + _nzsl_matrix[u32(2)] = input[u32(2)].xyz; + return _nzsl_matrix; } fn testMat4ToMat4(input: mat4[f32]) -> mat4[f32] @@ -320,11 +432,21 @@ fn testMat4CompoundMinusMat4(x: mat4[f32], y: mat4[f32]) -> mat4[f32] nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); - nzsl::Ast::SanitizeVisitor::Options options; - options.removeMatrixBinaryAddSub = true; - options.removeCompoundAssignments = true; + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule)); - REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + nzsl::Ast::Transformer::Context context; + + nzsl::Ast::AssignmentTransformer::Options assignmentOptions; + assignmentOptions.removeCompoundAssignment = true; + + nzsl::Ast::AssignmentTransformer assignmentTransformer; + REQUIRE_NOTHROW(assignmentTransformer.Transform(*shaderModule, context, assignmentOptions)); + + nzsl::Ast::MatrixTransformer::Options matrixOptions; + matrixOptions.removeMatrixBinaryAddSub = true; + + nzsl::Ast::MatrixTransformer matrixTransformer; + REQUIRE_NOTHROW(matrixTransformer.Transform(*shaderModule, context, matrixOptions)); ExpectNZSL(*shaderModule, R"( fn testMat4PlusMat4(x: mat4[f32], y: mat4[f32]) -> mat4[f32] @@ -339,8 +461,8 @@ fn testMat4SubMat4(x: mat4[f32], y: mat4[f32]) -> mat4[f32] fn testMat4SubMat4TimesMat4(x: mat4[f32], y: mat4[f32], z: mat4[f32]) -> mat4[f32] { - let cachedResult: mat4[f32] = y * y; - return mat4[f32](x[u32(0)] - cachedResult[u32(0)], x[u32(1)] - cachedResult[u32(1)], x[u32(2)] - cachedResult[u32(2)], x[u32(3)] - cachedResult[u32(3)]); + let _nzsl_cachedResult: mat4[f32] = y * y; + return mat4[f32](x[u32(0)] - _nzsl_cachedResult[u32(0)], x[u32(1)] - _nzsl_cachedResult[u32(1)], x[u32(2)] - _nzsl_cachedResult[u32(2)], x[u32(3)] - _nzsl_cachedResult[u32(3)]); } fn testMat4CompoundPlusMat4(x: mat4[f32], y: mat4[f32]) -> mat4[f32] @@ -358,40 +480,47 @@ fn testMat4CompoundMinusMat4(x: mat4[f32], y: mat4[f32]) -> mat4[f32] WHEN("Removing matrix casts") { - options.removeMatrixCast = true; + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule)); - REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + nzsl::Ast::Transformer::Context context2; + context2.nextVariableIndex = 20; + + nzsl::Ast::MatrixTransformer::Options matrixOptions2; + matrixOptions2.removeMatrixCast = true; + + nzsl::Ast::MatrixTransformer matrixTransformer2; + REQUIRE_NOTHROW(matrixTransformer2.Transform(*shaderModule, context2, matrixOptions2)); ExpectNZSL(*shaderModule, R"( fn testMat4PlusMat4(x: mat4[f32], y: mat4[f32]) -> mat4[f32] { - let temp: mat4[f32]; - temp[u32(0)] = x[u32(0)] + y[u32(0)]; - temp[u32(1)] = x[u32(1)] + y[u32(1)]; - temp[u32(2)] = x[u32(2)] + y[u32(2)]; - temp[u32(3)] = x[u32(3)] + y[u32(3)]; - return temp; + let _nzsl_matrix: mat4[f32]; + _nzsl_matrix[u32(0)] = x[u32(0)] + y[u32(0)]; + _nzsl_matrix[u32(1)] = x[u32(1)] + y[u32(1)]; + _nzsl_matrix[u32(2)] = x[u32(2)] + y[u32(2)]; + _nzsl_matrix[u32(3)] = x[u32(3)] + y[u32(3)]; + return _nzsl_matrix; } fn testMat4SubMat4(x: mat4[f32], y: mat4[f32]) -> mat4[f32] { - let temp: mat4[f32]; - temp[u32(0)] = x[u32(0)] - y[u32(0)]; - temp[u32(1)] = x[u32(1)] - y[u32(1)]; - temp[u32(2)] = x[u32(2)] - y[u32(2)]; - temp[u32(3)] = x[u32(3)] - y[u32(3)]; - return temp; + let _nzsl_matrix: mat4[f32]; + _nzsl_matrix[u32(0)] = x[u32(0)] - y[u32(0)]; + _nzsl_matrix[u32(1)] = x[u32(1)] - y[u32(1)]; + _nzsl_matrix[u32(2)] = x[u32(2)] - y[u32(2)]; + _nzsl_matrix[u32(3)] = x[u32(3)] - y[u32(3)]; + return _nzsl_matrix; } fn testMat4SubMat4TimesMat4(x: mat4[f32], y: mat4[f32], z: mat4[f32]) -> mat4[f32] { - let cachedResult: mat4[f32] = y * y; - let temp: mat4[f32]; - temp[u32(0)] = x[u32(0)] - cachedResult[u32(0)]; - temp[u32(1)] = x[u32(1)] - cachedResult[u32(1)]; - temp[u32(2)] = x[u32(2)] - cachedResult[u32(2)]; - temp[u32(3)] = x[u32(3)] - cachedResult[u32(3)]; - return temp; + let _nzsl_cachedResult: mat4[f32] = y * y; + let _nzsl_matrix: mat4[f32]; + _nzsl_matrix[u32(0)] = x[u32(0)] - _nzsl_cachedResult[u32(0)]; + _nzsl_matrix[u32(1)] = x[u32(1)] - _nzsl_cachedResult[u32(1)]; + _nzsl_matrix[u32(2)] = x[u32(2)] - _nzsl_cachedResult[u32(2)]; + _nzsl_matrix[u32(3)] = x[u32(3)] - _nzsl_cachedResult[u32(3)]; + return _nzsl_matrix; } )"); } @@ -444,27 +573,49 @@ external [nzsl_version("1.0")] module; +fn expr() -> i32 +{ + return 1.0; +} + fn main() { let value = 42.0; - let y = value.r; - let z = value.xxxx; + let x = value.r; + let y = value.xxxx; + let z = expr().xxx; } )"; nzsl::Ast::ModulePtr shaderModule = nzsl::Parse(nzslSource); - nzsl::Ast::SanitizeVisitor::Options options; - options.removeScalarSwizzling = true; + REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule)); - REQUIRE_NOTHROW(shaderModule = nzsl::Ast::Sanitize(*shaderModule, options)); + nzsl::Ast::SwizzleTransformer::Options swizzleOptions; + swizzleOptions.removeScalarSwizzling = true; + + nzsl::Ast::SwizzleTransformer swizzleTransformer; + nzsl::Ast::Transformer::Context context; + context.nextVariableIndex = 10; + + REQUIRE_NOTHROW(swizzleTransformer.Transform(*shaderModule, context, swizzleOptions)); ExpectNZSL(*shaderModule, R"( +[nzsl_version("1.0")] +module; + +fn expr() -> i32 +{ + return 1.0; +} + fn main() { let value: f32 = 42.0; - let y: f32 = value; - let z: vec4[f32] = vec4[f32](value, value, value, value); + let x: f32 = value; + let y: vec4[f32] = vec4[f32](value, value, value, value); + let _nzsl_cachedResult: i32 = expr(); + let z: vec3[i32] = vec3[i32](_nzsl_cachedResult, _nzsl_cachedResult, _nzsl_cachedResult); } )");