diff --git a/include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp b/include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp new file mode 100644 index 0000000..eeb0022 --- /dev/null +++ b/include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP +#define NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP + +#include + +namespace nzsl::Ast +{ + class NZSL_API ForToWhileTransformer final : public StatementTransformer + { + public: + struct Options; + + inline bool Transform(Module& module, Context& context, std::string* error = nullptr); + bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr); + + struct Options + { + bool allowPartialSanitization = false; + bool reduceForEachLoopsToWhile = true; + bool reduceForLoopsToWhile = true; + }; + + private: + using StatementTransformer::Transform; + StatementPtr Transform(ForEachStatement&& statement) override; + StatementPtr Transform(ForStatement&& statement) override; + + const Options* m_options; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/ForToWhileTransformer.inl b/include/NZSL/Ast/Transformations/ForToWhileTransformer.inl new file mode 100644 index 0000000..b31c0a8 --- /dev/null +++ b/include/NZSL/Ast/Transformations/ForToWhileTransformer.inl @@ -0,0 +1,11 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +namespace nzsl::Ast +{ + inline bool ForToWhileTransformer::Transform(Module& module, Context& context, std::string* error) + { + return Transform(module, context, {}, error); + } +} diff --git a/include/NZSL/Ast/Transformations/StatementTransformer.hpp b/include/NZSL/Ast/Transformations/StatementTransformer.hpp new file mode 100644 index 0000000..4949814 --- /dev/null +++ b/include/NZSL/Ast/Transformations/StatementTransformer.hpp @@ -0,0 +1,64 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP +#define NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP + +#include +#include +#include + +namespace nzsl::Ast +{ + class NZSL_API StatementTransformer : public StatementVisitor + { + public: + struct Context + { + std::size_t nextVariableIndex; + }; + + protected: +#define NZSL_SHADERAST_STATEMENT(Node) virtual StatementPtr Transform(Node##Statement&& statement); +#include + + bool TransformModule(Module& module, Context& context, std::string* error = nullptr); + void TransformStatement(StatementPtr& statement); + + Context* m_context; + + private: + template bool TransformCurrent(); + + void Visit(BranchStatement& node) override; + void Visit(BreakStatement& node) override; + void Visit(ConditionalStatement& node) override; + void Visit(ContinueStatement& node) override; + void Visit(DeclareAliasStatement& node) override; + void Visit(DeclareConstStatement& node) override; + void Visit(DeclareExternalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareOptionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + void Visit(DiscardStatement& node) override; + void Visit(ExpressionStatement& node) override; + void Visit(ForStatement& node) override; + void Visit(ForEachStatement& node) override; + void Visit(ImportStatement& node) override; + void Visit(MultiStatement& node) override; + void Visit(NoOpStatement& node) override; + void Visit(ReturnStatement& node) override; + void Visit(ScopedStatement& node) override; + void Visit(WhileStatement& node) override; + + std::vector m_statementStack; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/StatementTransformer.inl b/include/NZSL/Ast/Transformations/StatementTransformer.inl new file mode 100644 index 0000000..2ce0972 --- /dev/null +++ b/include/NZSL/Ast/Transformations/StatementTransformer.inl @@ -0,0 +1,7 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +namespace nzsl::Ast +{ +} diff --git a/src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp b/src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp new file mode 100644 index 0000000..109f2a9 --- /dev/null +++ b/src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp @@ -0,0 +1,185 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace nzsl::Ast +{ + bool ForToWhileTransformer::Transform(Module& module, Context& context, const Options& options, std::string* error) + { + m_options = &options; + + return TransformModule(module, context, error); + } + + StatementPtr ForToWhileTransformer::Transform(ForEachStatement&& forEachStatement) + { + if (!m_options->reduceForEachLoopsToWhile) + return nullptr; + + const ExpressionType* exprType = GetExpressionType(*forEachStatement.expression); + if (!exprType) + { + if (!m_options->allowPartialSanitization) + throw AstInternalError{ forEachStatement.sourceLocation, "unexpected missing expression type" }; + + return nullptr; + } + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + + ExpressionType innerType; + if (IsArrayType(resolvedExprType)) + { + const ArrayType& arrayType = std::get(resolvedExprType); + innerType = arrayType.containedType->type; + } + else + throw CompilerForEachUnsupportedTypeError{ forEachStatement.sourceLocation, ToString(*exprType) }; + + auto multi = std::make_unique(); + multi->sourceLocation = forEachStatement.sourceLocation; + + if (IsArrayType(resolvedExprType)) + { + const ArrayType& arrayType = std::get(resolvedExprType); + + multi->statements.reserve(2); + + // Counter variable + auto counterVariable = ShaderBuilder::DeclareVariable("_nzsl_counter", ShaderBuilder::ConstantValue(0u)); + counterVariable->sourceLocation = forEachStatement.sourceLocation; + counterVariable->varIndex = m_context->nextVariableIndex++; + + std::size_t counterVarIndex = counterVariable->varIndex.value(); + + multi->statements.emplace_back(std::move(counterVariable)); + + auto whileStatement = std::make_unique(); + whileStatement->unroll = std::move(forEachStatement.unroll); + + // While condition + auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation), ShaderBuilder::ConstantValue(arrayType.length, forEachStatement.sourceLocation)); + whileStatement->condition = std::move(condition); + + // While body + auto body = std::make_unique(); + body->statements.reserve(3); + + auto accessIndex = ShaderBuilder::AccessIndex(std::move(forEachStatement.expression), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation)); + + auto elementVariable = ShaderBuilder::DeclareVariable(forEachStatement.varName, std::move(accessIndex)); + elementVariable->varIndex = forEachStatement.varIndex; //< Preserve var index + + body->statements.emplace_back(std::move(elementVariable)); + body->statements.emplace_back(std::move(forEachStatement.statement)); + + auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation), ShaderBuilder::ConstantValue(1u, forEachStatement.sourceLocation)); + + body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); + + whileStatement->body = std::move(body); + + multi->statements.emplace_back(std::move(whileStatement)); + } + + return ShaderBuilder::Scoped(std::move(multi)); + } + + StatementPtr ForToWhileTransformer::Transform(ForStatement&& forStatement) + { + if (!m_options->reduceForLoopsToWhile) + return nullptr; + + Expression& fromExpr = *forStatement.fromExpr; + const ExpressionType* fromExprType = GetExpressionType(fromExpr); + if (!fromExprType) + { + if (!m_options->allowPartialSanitization) + throw AstInternalError{ forStatement.sourceLocation, "unexpected missing expression type" }; + + return nullptr; + } + + const ExpressionType& resolvedFromExprType = ResolveAlias(*fromExprType); + if (!IsPrimitiveType(resolvedFromExprType)) + throw CompilerForFromTypeExpectIntegerTypeError{ fromExpr.sourceLocation, ToString(*fromExprType) }; + + PrimitiveType counterType = std::get(resolvedFromExprType); + if (counterType != PrimitiveType::Int32 && counterType != PrimitiveType::UInt32) + throw CompilerForFromTypeExpectIntegerTypeError{ fromExpr.sourceLocation, ToString(*fromExprType) }; + + auto multi = std::make_unique(); + multi->sourceLocation = forStatement.sourceLocation; + + // Counter variable + auto counterVariable = ShaderBuilder::DeclareVariable(forStatement.varName, std::move(forStatement.fromExpr)); + counterVariable->sourceLocation = forStatement.sourceLocation; + counterVariable->varIndex = forStatement.varIndex; + + std::size_t counterVarIndex = counterVariable->varIndex.value(); + multi->statements.emplace_back(std::move(counterVariable)); + + // Target variable + auto targetVariable = ShaderBuilder::DeclareVariable("_nzsl_to", std::move(forStatement.toExpr)); + targetVariable->sourceLocation = forStatement.sourceLocation; + targetVariable->varIndex = m_context->nextVariableIndex++; + + std::size_t targetVarIndex = targetVariable->varIndex.value(); + multi->statements.emplace_back(std::move(targetVariable)); + + // Step variable + std::optional stepVarIndex; + + if (forStatement.stepExpr) + { + auto stepVariable = ShaderBuilder::DeclareVariable("_nzsl_step", std::move(forStatement.stepExpr)); + stepVariable->sourceLocation = forStatement.sourceLocation; + stepVariable->varIndex = m_context->nextVariableIndex++; + + stepVarIndex = stepVariable->varIndex; + multi->statements.emplace_back(std::move(stepVariable)); + } + + // While + auto whileStatement = std::make_unique(); + whileStatement->sourceLocation = forStatement.sourceLocation; + whileStatement->unroll = std::move(forStatement.unroll); + + // While condition + auto conditionCounterVariable = ShaderBuilder::Variable(counterVarIndex, counterType, forStatement.sourceLocation); + auto conditionTargetVariable = ShaderBuilder::Variable(targetVarIndex, counterType, forStatement.sourceLocation); + + auto condition = ShaderBuilder::Binary(BinaryType::CompLt, std::move(conditionCounterVariable), std::move(conditionTargetVariable)); + condition->sourceLocation = forStatement.sourceLocation; + + whileStatement->condition = std::move(condition); + + // While body + auto body = std::make_unique(); + body->statements.reserve(2); + + // Counter and increment + ExpressionPtr incrExpr; + if (stepVarIndex) + incrExpr = ShaderBuilder::Variable(*stepVarIndex, counterType, forStatement.sourceLocation); + else + incrExpr = (counterType == PrimitiveType::Int32) ? ShaderBuilder::ConstantValue(1, forStatement.sourceLocation) : ShaderBuilder::ConstantValue(1u, forStatement.sourceLocation); + + incrExpr->sourceLocation = forStatement.sourceLocation; + + auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType, forStatement.sourceLocation), std::move(incrExpr)); + incrCounter->sourceLocation = forStatement.sourceLocation; + + body->statements.emplace_back(std::move(forStatement.statement)); + body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); + + whileStatement->body = std::move(body); + + multi->statements.emplace_back(std::move(whileStatement)); + + return ShaderBuilder::Scoped(std::move(multi)); + } +} diff --git a/src/NZSL/Ast/Transformations/StatementTransformer.cpp b/src/NZSL/Ast/Transformations/StatementTransformer.cpp new file mode 100644 index 0000000..fbd9c48 --- /dev/null +++ b/src/NZSL/Ast/Transformations/StatementTransformer.cpp @@ -0,0 +1,194 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include + +namespace nzsl::Ast +{ +#define NZSL_SHADERAST_STATEMENT(Node) \ + StatementPtr StatementTransformer::Transform(Node##Statement&& /*statement*/) \ + { \ + return nullptr; \ + } + +#include + + bool StatementTransformer::TransformModule(Module& module, Context& context, std::string* error) + { + m_context = &context; + + try + { + StatementPtr root = std::move(module.rootNode); + TransformStatement(root); + module.rootNode = Nz::StaticUniquePointerCast(std::move(root)); + } + catch(const std::exception& e) + { + if (!error) + throw; + + *error = e.what(); + return false; + } + + return true; + } + + void StatementTransformer::TransformStatement(StatementPtr& statement) + { + assert(statement); + + m_statementStack.push_back(&statement); + statement->Visit(*this); + m_statementStack.pop_back(); + } + + template + bool StatementTransformer::TransformCurrent() + { + StatementPtr newStatement = Transform(std::move(Nz::SafeCast(**m_statementStack.back()))); + if (!newStatement) + return false; + + *m_statementStack.back() = std::move(newStatement); + return true; + } + + void StatementTransformer::Visit(BranchStatement& node) + { + if (TransformCurrent()) + return; + + for (auto& cond : node.condStatements) + TransformStatement(cond.statement); + + if (node.elseStatement) + TransformStatement(node.elseStatement); + } + + void StatementTransformer::Visit(BreakStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(ConditionalStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(ContinueStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(DeclareAliasStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(DeclareConstStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(DeclareExternalStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(DeclareFunctionStatement& node) + { + if (TransformCurrent()) + return; + + for (auto& statement : node.statements) + TransformStatement(statement); + } + + void StatementTransformer::Visit(DeclareOptionStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(DeclareStructStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(DeclareVariableStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(DiscardStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(ExpressionStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(ForStatement& node) + { + if (TransformCurrent()) + return; + + if (node.statement) + TransformStatement(node.statement); + } + + void StatementTransformer::Visit(ForEachStatement& node) + { + if (TransformCurrent()) + return; + + if (node.statement) + TransformStatement(node.statement); + } + + void StatementTransformer::Visit(ImportStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(MultiStatement& node) + { + if (TransformCurrent()) + return; + + for (auto& statement : node.statements) + TransformStatement(statement); + } + + void StatementTransformer::Visit(NoOpStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(ReturnStatement& /*node*/) + { + TransformCurrent(); + } + + void StatementTransformer::Visit(ScopedStatement& node) + { + if (TransformCurrent()) + return; + + if (node.statement) + TransformStatement(node.statement); + } + + void StatementTransformer::Visit(WhileStatement& node) + { + if (TransformCurrent()) + return; + + if (node.body) + TransformStatement(node.body); + } +}