diff --git a/include/NZSL/Ast/Transformations/AssignmentTransformer.hpp b/include/NZSL/Ast/Transformations/AssignmentTransformer.hpp new file mode 100644 index 0000000..6c1394c --- /dev/null +++ b/include/NZSL/Ast/Transformations/AssignmentTransformer.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_ASSIGNMENTTRANSFORMER_HPP +#define NZSL_AST_TRANSFORMATIONS_ASSIGNMENTTRANSFORMER_HPP + +#include + +namespace nzsl::Ast +{ + class NZSL_API AssignmentTransformer final : public Transformer + { + public: + struct Options; + + inline AssignmentTransformer(); + + inline bool Transform(Module& module, Context& context, std::string* error = nullptr); + bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr); + + struct Options + { + bool allowPartialSanitization = false; + bool removeCompoundAssignment = false; + }; + + private: + using Transformer::Transform; + ExpressionPtr Transform(AssignExpression&& assign) override; + + const Options* m_options; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_ASSIGNMENTTRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/AssignmentTransformer.inl b/include/NZSL/Ast/Transformations/AssignmentTransformer.inl new file mode 100644 index 0000000..92cbee3 --- /dev/null +++ b/include/NZSL/Ast/Transformations/AssignmentTransformer.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +namespace nzsl::Ast +{ + inline AssignmentTransformer::AssignmentTransformer() : + Transformer(true) + { + } + + inline bool AssignmentTransformer::Transform(Module& module, Context& context, std::string* error) + { + return Transform(module, context, {}, error); + } +} diff --git a/include/NZSL/Ast/Transformations/BranchSplitterTransformer.hpp b/include/NZSL/Ast/Transformations/BranchSplitterTransformer.hpp new file mode 100644 index 0000000..0117eca --- /dev/null +++ b/include/NZSL/Ast/Transformations/BranchSplitterTransformer.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_BRANCHSPLITTERTRANSFORMER_HPP +#define NZSL_AST_TRANSFORMATIONS_BRANCHSPLITTERTRANSFORMER_HPP + +#include + +namespace nzsl::Ast +{ + class NZSL_API BranchSplitterTransformer final : public Transformer + { + public: + struct Options; + + inline BranchSplitterTransformer(); + + inline bool Transform(Module& module, Context& context, std::string* error = nullptr); + bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr); + + struct Options + { + bool allowPartialSanitization = false; + }; + + private: + using Transformer::Transform; + StatementPtr Transform(BranchStatement&& statement) override; + + const Options* m_options; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_BRANCHSPLITTERTRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/BranchSplitterTransformer.inl b/include/NZSL/Ast/Transformations/BranchSplitterTransformer.inl new file mode 100644 index 0000000..29b12e8 --- /dev/null +++ b/include/NZSL/Ast/Transformations/BranchSplitterTransformer.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +namespace nzsl::Ast +{ + inline BranchSplitterTransformer::BranchSplitterTransformer() : + Transformer(false) + { + } + + inline bool BranchSplitterTransformer::Transform(Module& module, Context& context, std::string* error) + { + return Transform(module, context, {}, error); + } +} diff --git a/include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp b/include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp index eeb0022..257df89 100644 --- a/include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp +++ b/include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp @@ -7,15 +7,17 @@ #ifndef NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP #define NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP -#include +#include namespace nzsl::Ast { - class NZSL_API ForToWhileTransformer final : public StatementTransformer + class NZSL_API ForToWhileTransformer final : public Transformer { public: struct Options; + inline ForToWhileTransformer(); + inline bool Transform(Module& module, Context& context, std::string* error = nullptr); bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr); @@ -27,7 +29,7 @@ namespace nzsl::Ast }; private: - using StatementTransformer::Transform; + using Transformer::Transform; StatementPtr Transform(ForEachStatement&& statement) override; StatementPtr Transform(ForStatement&& statement) override; diff --git a/include/NZSL/Ast/Transformations/ForToWhileTransformer.inl b/include/NZSL/Ast/Transformations/ForToWhileTransformer.inl index b31c0a8..9f5696d 100644 --- a/include/NZSL/Ast/Transformations/ForToWhileTransformer.inl +++ b/include/NZSL/Ast/Transformations/ForToWhileTransformer.inl @@ -4,6 +4,11 @@ namespace nzsl::Ast { + inline ForToWhileTransformer::ForToWhileTransformer() : + Transformer(false) + { + } + inline bool ForToWhileTransformer::Transform(Module& module, Context& context, std::string* error) { return Transform(module, context, {}, error); diff --git a/include/NZSL/Ast/Transformations/MatrixTransformer.hpp b/include/NZSL/Ast/Transformations/MatrixTransformer.hpp new file mode 100644 index 0000000..d085fe4 --- /dev/null +++ b/include/NZSL/Ast/Transformations/MatrixTransformer.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_MATRIXTRANSFORMER_HPP +#define NZSL_AST_TRANSFORMATIONS_MATRIXTRANSFORMER_HPP + +#include + +namespace nzsl::Ast +{ + class NZSL_API MatrixTransformer final : public Transformer + { + public: + struct Options; + + inline MatrixTransformer(); + + inline bool Transform(Module& module, Context& context, std::string* error = nullptr); + bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr); + + struct Options + { + bool removeMatrixBinaryAddSub = false; + bool removeMatrixCast = false; + }; + + private: + using Transformer::Transform; + + ExpressionPtr Transform(BinaryExpression&& binExpr) override; + ExpressionPtr Transform(CastExpression&& castExpr) override; + + const Options* m_options; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_MATRIXTRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/StatementTransformer.inl b/include/NZSL/Ast/Transformations/MatrixTransformer.inl similarity index 51% rename from include/NZSL/Ast/Transformations/StatementTransformer.inl rename to include/NZSL/Ast/Transformations/MatrixTransformer.inl index 2ce0972..4b4f293 100644 --- a/include/NZSL/Ast/Transformations/StatementTransformer.inl +++ b/include/NZSL/Ast/Transformations/MatrixTransformer.inl @@ -4,4 +4,13 @@ namespace nzsl::Ast { + inline MatrixTransformer::MatrixTransformer() : + Transformer(true) + { + } + + inline bool MatrixTransformer::Transform(Module& module, Context& context, std::string* error) + { + return Transform(module, context, {}, error); + } } diff --git a/include/NZSL/Ast/Transformations/StatementTransformer.hpp b/include/NZSL/Ast/Transformations/StatementTransformer.hpp deleted file mode 100644 index 4949814..0000000 --- a/include/NZSL/Ast/Transformations/StatementTransformer.hpp +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) -// This file is part of the "Nazara Shading Language" project -// For conditions of distribution and use, see copyright notice in Config.hpp - -#pragma once - -#ifndef NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP -#define NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP - -#include -#include -#include - -namespace nzsl::Ast -{ - class NZSL_API StatementTransformer : public StatementVisitor - { - public: - struct Context - { - std::size_t nextVariableIndex; - }; - - protected: -#define NZSL_SHADERAST_STATEMENT(Node) virtual StatementPtr Transform(Node##Statement&& statement); -#include - - bool TransformModule(Module& module, Context& context, std::string* error = nullptr); - void TransformStatement(StatementPtr& statement); - - Context* m_context; - - private: - template bool TransformCurrent(); - - void Visit(BranchStatement& node) override; - void Visit(BreakStatement& node) override; - void Visit(ConditionalStatement& node) override; - void Visit(ContinueStatement& node) override; - void Visit(DeclareAliasStatement& node) override; - void Visit(DeclareConstStatement& node) override; - void Visit(DeclareExternalStatement& node) override; - void Visit(DeclareFunctionStatement& node) override; - void Visit(DeclareOptionStatement& node) override; - void Visit(DeclareStructStatement& node) override; - void Visit(DeclareVariableStatement& node) override; - void Visit(DiscardStatement& node) override; - void Visit(ExpressionStatement& node) override; - void Visit(ForStatement& node) override; - void Visit(ForEachStatement& node) override; - void Visit(ImportStatement& node) override; - void Visit(MultiStatement& node) override; - void Visit(NoOpStatement& node) override; - void Visit(ReturnStatement& node) override; - void Visit(ScopedStatement& node) override; - void Visit(WhileStatement& node) override; - - std::vector m_statementStack; - }; -} - -#include - -#endif // NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/SwizzleTransformer.hpp b/include/NZSL/Ast/Transformations/SwizzleTransformer.hpp new file mode 100644 index 0000000..409804d --- /dev/null +++ b/include/NZSL/Ast/Transformations/SwizzleTransformer.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_SWIZZLETRANSFORMER_HPP +#define NZSL_AST_TRANSFORMATIONS_SWIZZLETRANSFORMER_HPP + +#include + +namespace nzsl::Ast +{ + class NZSL_API SwizzleTransformer final : public Transformer + { + public: + struct Options; + + inline SwizzleTransformer(); + + inline bool Transform(Module& module, Context& context, std::string* error = nullptr); + bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr); + + struct Options + { + bool allowPartialSanitization = false; + bool removeScalarSwizzling = false; + }; + + private: + using Transformer::Transform; + ExpressionPtr Transform(SwizzleExpression&& swizzle) override; + + const Options* m_options; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_SWIZZLETRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/SwizzleTransformer.inl b/include/NZSL/Ast/Transformations/SwizzleTransformer.inl new file mode 100644 index 0000000..dc8da2c --- /dev/null +++ b/include/NZSL/Ast/Transformations/SwizzleTransformer.inl @@ -0,0 +1,16 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +namespace nzsl::Ast +{ + inline SwizzleTransformer::SwizzleTransformer() : + Transformer(true) + { + } + + inline bool SwizzleTransformer::Transform(Module& module, Context& context, std::string* error) + { + return Transform(module, context, {}, error); + } +} diff --git a/include/NZSL/Ast/Transformations/Transformer.hpp b/include/NZSL/Ast/Transformations/Transformer.hpp new file mode 100644 index 0000000..9099f69 --- /dev/null +++ b/include/NZSL/Ast/Transformations/Transformer.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NZSL_AST_TRANSFORMATIONS_TRANSFORMER_HPP +#define NZSL_AST_TRANSFORMATIONS_TRANSFORMER_HPP + +#include +#include +#include +#include + +namespace nzsl::Ast +{ + class NZSL_API Transformer : public ExpressionVisitor, public StatementVisitor + { + public: + struct Context + { + std::size_t nextVariableIndex; + bool allowPartialSanitization = false; + }; + + static StatementPtr Unscope(StatementPtr&& statement); + + protected: + inline Transformer(bool visitExpressions); + + void AppendStatement(StatementPtr statement); + + ExpressionPtr CacheExpression(ExpressionPtr expression); + + DeclareVariableStatement* DeclareVariable(std::string_view name, ExpressionPtr initialExpr); + DeclareVariableStatement* DeclareVariable(std::string_view name, Ast::ExpressionType type, SourceLocation sourceLocation); + + const ExpressionType* GetExpressionType(Expression& expr) const; + const ExpressionType* GetExpressionType(Expression& expr, bool allowEmpty) const; + + template void HandleStatementList(std::vector& statementList, F&& callback); + +#define NZSL_SHADERAST_NODE(Node, Type) virtual Type##Ptr Transform(Node##Type&& node); +#include + + void TransformExpression(ExpressionPtr& expression); + bool TransformModule(Module& module, Context& context, std::string* error = nullptr); + void TransformStatement(StatementPtr& statement); + + Context* m_context; + + private: + template bool TransformCurrentExpression(); + template bool TransformCurrentStatement(); + + void Visit(AccessIdentifierExpression& node) override; + void Visit(AccessIndexExpression& node) override; + void Visit(AliasValueExpression& node) override; + void Visit(AssignExpression& node) override; + void Visit(BinaryExpression& node) override; + void Visit(CallFunctionExpression& node) override; + void Visit(CallMethodExpression& node) override; + void Visit(CastExpression& node) override; + void Visit(ConditionalExpression& node) override; + void Visit(ConstantExpression& node) override; + void Visit(ConstantArrayValueExpression& node) override; + void Visit(ConstantValueExpression& node) override; + void Visit(FunctionExpression& node) override; + void Visit(IdentifierExpression& node) override; + void Visit(IntrinsicExpression& node) override; + void Visit(IntrinsicFunctionExpression& node) override; + void Visit(StructTypeExpression& node) override; + void Visit(SwizzleExpression& node) override; + void Visit(TypeExpression& node) override; + void Visit(VariableValueExpression& node) override; + void Visit(UnaryExpression& node) override; + + void Visit(BranchStatement& node) override; + void Visit(BreakStatement& node) override; + void Visit(ConditionalStatement& node) override; + void Visit(ContinueStatement& node) override; + void Visit(DeclareAliasStatement& node) override; + void Visit(DeclareConstStatement& node) override; + void Visit(DeclareExternalStatement& node) override; + void Visit(DeclareFunctionStatement& node) override; + void Visit(DeclareOptionStatement& node) override; + void Visit(DeclareStructStatement& node) override; + void Visit(DeclareVariableStatement& node) override; + void Visit(DiscardStatement& node) override; + void Visit(ExpressionStatement& node) override; + void Visit(ForStatement& node) override; + void Visit(ForEachStatement& node) override; + void Visit(ImportStatement& node) override; + void Visit(MultiStatement& node) override; + void Visit(NoOpStatement& node) override; + void Visit(ReturnStatement& node) override; + void Visit(ScopedStatement& node) override; + void Visit(WhileStatement& node) override; + + std::size_t m_currentStatementListIndex; + std::vector m_expressionStack; + std::vector m_statementStack; + std::vector* m_currentStatementList; + bool m_visitExpressions; + }; +} + +#include + +#endif // NZSL_AST_TRANSFORMATIONS_TRANSFORMER_HPP diff --git a/include/NZSL/Ast/Transformations/Transformer.inl b/include/NZSL/Ast/Transformations/Transformer.inl new file mode 100644 index 0000000..34baebf --- /dev/null +++ b/include/NZSL/Ast/Transformations/Transformer.inl @@ -0,0 +1,32 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +namespace nzsl::Ast +{ + inline Transformer::Transformer(bool visitExpressions) : + m_visitExpressions(visitExpressions) + { + } + + template + void Transformer::HandleStatementList(std::vector& statementList, F&& callback) + { + std::vector* previousStatementList = m_currentStatementList; + std::size_t previousListIndex = m_currentStatementListIndex; + + m_currentStatementList = &statementList; + m_currentStatementListIndex = 0; + + if constexpr (Single) + callback(); + else + { + for (; m_currentStatementListIndex < statementList.size(); ++m_currentStatementListIndex) + callback(statementList[m_currentStatementListIndex]); + } + + m_currentStatementList = previousStatementList; + m_currentStatementListIndex = previousListIndex; + } +} diff --git a/src/NZSL/Ast/Transformations/AssignmentTransformer.cpp b/src/NZSL/Ast/Transformations/AssignmentTransformer.cpp new file mode 100644 index 0000000..fd5e1b2 --- /dev/null +++ b/src/NZSL/Ast/Transformations/AssignmentTransformer.cpp @@ -0,0 +1,41 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include + +namespace nzsl::Ast +{ + bool AssignmentTransformer::Transform(Module& module, Context& context, const Options& options, std::string* error) + { + m_options = &options; + + return TransformModule(module, context, error); + } + + ExpressionPtr AssignmentTransformer::Transform(AssignExpression&& assign) + { + if (assign.op == AssignType::Simple || !m_options->removeCompoundAssignment) + return nullptr; + + BinaryType binaryType; + switch (assign.op) + { + case AssignType::Simple: NAZARA_UNREACHABLE(); + case AssignType::CompoundAdd: binaryType = BinaryType::Add; break; + case AssignType::CompoundDivide: binaryType = BinaryType::Divide; break; + case AssignType::CompoundModulo: binaryType = BinaryType::Modulo; break; + case AssignType::CompoundMultiply: binaryType = BinaryType::Multiply; break; + case AssignType::CompoundLogicalAnd: binaryType = BinaryType::LogicalAnd; break; + case AssignType::CompoundLogicalOr: binaryType = BinaryType::LogicalOr; break; + case AssignType::CompoundSubtract: binaryType = BinaryType::Subtract; break; + } + + assign.op = AssignType::Simple; + assign.right = ShaderBuilder::Binary(binaryType, Clone(*assign.left), std::move(assign.right)); + + return nullptr; + } +} diff --git a/src/NZSL/Ast/Transformations/BranchSplitterTransformer.cpp b/src/NZSL/Ast/Transformations/BranchSplitterTransformer.cpp new file mode 100644 index 0000000..f64f35d --- /dev/null +++ b/src/NZSL/Ast/Transformations/BranchSplitterTransformer.cpp @@ -0,0 +1,38 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace nzsl::Ast +{ + bool BranchSplitterTransformer::Transform(Module& module, Context& context, const Options& options, std::string* error) + { + m_options = &options; + + return TransformModule(module, context, error); + } + + StatementPtr BranchSplitterTransformer::Transform(BranchStatement&& branchStatement) + { + if (branchStatement.condStatements.size() < 2) + return nullptr; + + StatementPtr elseStatement = std::move(branchStatement.elseStatement); + for (std::size_t i = branchStatement.condStatements.size() - 1; i >= 1; --i) + { + auto& condStatement = branchStatement.condStatements[i]; + + SourceLocation sourceLocation = SourceLocation::BuildFromTo(condStatement.condition->sourceLocation, (elseStatement) ? elseStatement->sourceLocation : condStatement.statement->sourceLocation); + + elseStatement = ShaderBuilder::Branch(std::move(condStatement.condition), std::move(condStatement.statement), std::move(elseStatement)); + elseStatement->sourceLocation = std::move(sourceLocation); + } + + branchStatement.condStatements.resize(1); + branchStatement.elseStatement = std::move(elseStatement); + + return nullptr; + } +} diff --git a/src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp b/src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp index 109f2a9..ed4f659 100644 --- a/src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp +++ b/src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp @@ -22,7 +22,7 @@ namespace nzsl::Ast const ExpressionType* exprType = GetExpressionType(*forEachStatement.expression); if (!exprType) { - if (!m_options->allowPartialSanitization) + if (!m_context->allowPartialSanitization) throw AstInternalError{ forEachStatement.sourceLocation, "unexpected missing expression type" }; return nullptr; @@ -49,7 +49,7 @@ namespace nzsl::Ast multi->statements.reserve(2); // Counter variable - auto counterVariable = ShaderBuilder::DeclareVariable("_nzsl_counter", ShaderBuilder::ConstantValue(0u)); + auto counterVariable = ShaderBuilder::DeclareVariable("_nzsl_counter", ExpressionType{ PrimitiveType::UInt32 }, ShaderBuilder::ConstantValue(0u)); counterVariable->sourceLocation = forEachStatement.sourceLocation; counterVariable->varIndex = m_context->nextVariableIndex++; @@ -74,7 +74,7 @@ namespace nzsl::Ast elementVariable->varIndex = forEachStatement.varIndex; //< Preserve var index body->statements.emplace_back(std::move(elementVariable)); - body->statements.emplace_back(std::move(forEachStatement.statement)); + body->statements.emplace_back(Unscope(std::move(forEachStatement.statement))); auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation), ShaderBuilder::ConstantValue(1u, forEachStatement.sourceLocation)); @@ -96,12 +96,7 @@ namespace nzsl::Ast Expression& fromExpr = *forStatement.fromExpr; const ExpressionType* fromExprType = GetExpressionType(fromExpr); if (!fromExprType) - { - if (!m_options->allowPartialSanitization) - throw AstInternalError{ forStatement.sourceLocation, "unexpected missing expression type" }; - return nullptr; - } const ExpressionType& resolvedFromExprType = ResolveAlias(*fromExprType); if (!IsPrimitiveType(resolvedFromExprType)) @@ -123,7 +118,7 @@ namespace nzsl::Ast multi->statements.emplace_back(std::move(counterVariable)); // Target variable - auto targetVariable = ShaderBuilder::DeclareVariable("_nzsl_to", std::move(forStatement.toExpr)); + auto targetVariable = ShaderBuilder::DeclareVariable("_nzsl_to", ExpressionType{ counterType }, std::move(forStatement.toExpr)); targetVariable->sourceLocation = forStatement.sourceLocation; targetVariable->varIndex = m_context->nextVariableIndex++; @@ -135,7 +130,7 @@ namespace nzsl::Ast if (forStatement.stepExpr) { - auto stepVariable = ShaderBuilder::DeclareVariable("_nzsl_step", std::move(forStatement.stepExpr)); + auto stepVariable = ShaderBuilder::DeclareVariable("_nzsl_step", ExpressionType{ counterType }, std::move(forStatement.stepExpr)); stepVariable->sourceLocation = forStatement.sourceLocation; stepVariable->varIndex = m_context->nextVariableIndex++; @@ -173,7 +168,7 @@ namespace nzsl::Ast auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType, forStatement.sourceLocation), std::move(incrExpr)); incrCounter->sourceLocation = forStatement.sourceLocation; - body->statements.emplace_back(std::move(forStatement.statement)); + body->statements.emplace_back(Unscope(std::move(forStatement.statement))); body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); whileStatement->body = std::move(body); diff --git a/src/NZSL/Ast/Transformations/MatrixTransformer.cpp b/src/NZSL/Ast/Transformations/MatrixTransformer.cpp new file mode 100644 index 0000000..4cbad6a --- /dev/null +++ b/src/NZSL/Ast/Transformations/MatrixTransformer.cpp @@ -0,0 +1,218 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include + +namespace nzsl::Ast +{ + bool MatrixTransformer::Transform(Module& module, Context& context, const Options& options, std::string* error) + { + m_options = &options; + + return TransformModule(module, context, error); + } + + ExpressionPtr MatrixTransformer::Transform(BinaryExpression&& binExpr) + { + if (!m_options->removeMatrixBinaryAddSub) + return nullptr; + + if (binExpr.op == BinaryType::Add || binExpr.op == BinaryType::Subtract) + { + const ExpressionType* leftExprType = GetExpressionType(*binExpr.left); + const ExpressionType* rightExprType = GetExpressionType(*binExpr.right); + if (IsMatrixType(*leftExprType) && IsMatrixType(*rightExprType)) + { + const MatrixType& matrixType = std::get(*leftExprType); + if (*leftExprType != *rightExprType) + throw AstInternalError{ binExpr.sourceLocation, "expected matrices of the same type" }; + + // Since we're going to access both matrices multiples times, make sure we cache them into variables if required + auto leftMatrix = CacheExpression(std::move(binExpr.left)); + auto rightMatrix = CacheExpression(std::move(binExpr.right)); + + std::vector columnExpressions(matrixType.columnCount); + + for (std::uint32_t i = 0; i < matrixType.columnCount; ++i) + { + // mat[i] + auto leftColumnExpr = ShaderBuilder::AccessIndex(Clone(*leftMatrix), ShaderBuilder::ConstantValue(i, binExpr.sourceLocation)); + leftColumnExpr->cachedExpressionType = VectorType{ matrixType.rowCount, matrixType.type }; + leftColumnExpr->sourceLocation = binExpr.sourceLocation; + + auto rightColumnExpr = ShaderBuilder::AccessIndex(Clone(*rightMatrix), ShaderBuilder::ConstantValue(i, binExpr.sourceLocation)); + rightColumnExpr->cachedExpressionType = VectorType{ matrixType.rowCount, matrixType.type }; + rightColumnExpr->sourceLocation = binExpr.sourceLocation; + + // lhs[i] [+|-] rhs[i] + auto binOp = ShaderBuilder::Binary(binExpr.op, std::move(leftColumnExpr), std::move(rightColumnExpr)); + binOp->cachedExpressionType = VectorType{ matrixType.rowCount, matrixType.type }; + binOp->sourceLocation = binExpr.sourceLocation; + + columnExpressions[i] = std::move(binOp); + } + + // Build resulting matrix + auto result = ShaderBuilder::Cast(*leftExprType, std::move(columnExpressions)); + result->cachedExpressionType = matrixType; + result->sourceLocation = binExpr.sourceLocation; + + return result; + } + } + + return nullptr; + } + + ExpressionPtr MatrixTransformer::Transform(CastExpression&& castExpr) + { + if (!m_options->removeMatrixCast) + return nullptr; + + if (!castExpr.targetType.IsResultingValue()) + { + if (m_context->allowPartialSanitization) + return nullptr; + + throw CompilerConstantExpressionRequiredError{ castExpr.targetType.GetExpression()->sourceLocation }; + } + + const ExpressionType& targetType = castExpr.targetType.GetResultingValue(); + + if (m_options->removeMatrixCast && IsMatrixType(targetType)) + { + const MatrixType& targetMatrixType = std::get(targetType); + + // Check if all types are known + for (std::size_t i = 0; i < castExpr.expressions.size(); ++i) + { + const ExpressionType* exprType = GetExpressionType(*castExpr.expressions[i]); + if (!exprType) + return nullptr; //< unresolved type + } + + const ExpressionType& resolvedFrontExprType = ResolveAlias(*GetExpressionType(*castExpr.expressions.front())); + bool isMatrixCast = IsMatrixType(resolvedFrontExprType); + if (isMatrixCast && std::get(resolvedFrontExprType) == targetMatrixType) + { + // Nothing to do + return std::move(castExpr.expressions.front()); + } + + auto* variableDeclaration = DeclareVariable("matrix", targetType, castExpr.sourceLocation); + + std::size_t variableIndex = *variableDeclaration->varIndex; + + ExpressionPtr cachedDiagonalValue; + + for (std::uint32_t i = 0; i < targetMatrixType.columnCount; ++i) + { + // temp[i] + auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, targetType, castExpr.sourceLocation), ShaderBuilder::ConstantValue(i, castExpr.sourceLocation)); + columnExpr->cachedExpressionType = VectorType{ targetMatrixType.rowCount, targetMatrixType.type }; + + // vector expression + ExpressionPtr vectorExpr; + std::size_t vectorComponentCount; + if (isMatrixCast) + { + if (!cachedDiagonalValue) + cachedDiagonalValue = CacheExpression(std::move(castExpr.expressions.front())); + + const MatrixType& fromMatrixType = std::get(resolvedFrontExprType); + + // fromMatrix[i] + auto matrixColumnExpr = ShaderBuilder::AccessIndex(Clone(*cachedDiagonalValue), ShaderBuilder::ConstantValue(i, castExpr.sourceLocation)); + matrixColumnExpr->cachedExpressionType = VectorType{ fromMatrixType.rowCount, fromMatrixType.type }; + + vectorExpr = std::move(matrixColumnExpr); + vectorComponentCount = fromMatrixType.rowCount; + } + else if (IsVectorType(resolvedFrontExprType)) + { + // parameter #i + vectorExpr = std::move(castExpr.expressions[i]); + vectorComponentCount = std::get(ResolveAlias(*GetExpressionType(*vectorExpr))).componentCount; + } + else + { + assert(IsPrimitiveType(resolvedFrontExprType)); + + // Use a Cast expression to replace swizzle + std::vector expressions(targetMatrixType.rowCount); + SourceLocation location; + for (std::size_t j = 0; j < targetMatrixType.rowCount; ++j) + { + if (castExpr.expressions.size() == 1) //< diagonal value + { + if (i == j) + expressions[j] = Clone(*cachedDiagonalValue); + else + expressions[j] = ShaderBuilder::ConstantValue(ExpressionType{ targetMatrixType.type }, 0, castExpr.sourceLocation); + } + else + expressions[j] = std::move(castExpr.expressions[i * targetMatrixType.rowCount + j]); + + if (j == 0) + location = expressions[j]->sourceLocation; + else + location.ExtendToRight(expressions[j]->sourceLocation); + } + + auto buildVec = ShaderBuilder::Cast(ExpressionType{ VectorType{ targetMatrixType.rowCount, targetMatrixType.type } }, std::move(expressions)); + buildVec->sourceLocation = location; + + vectorExpr = std::move(buildVec); + vectorComponentCount = targetMatrixType.rowCount; + } + + // cast expression (turn fromMatrix[i] to vec3[f32](fromMatrix[i])) + ExpressionPtr columnCastExpr; + if (vectorComponentCount != targetMatrixType.rowCount) + { + CastExpressionPtr vecCast; + if (vectorComponentCount < targetMatrixType.rowCount) + { + std::vector expressions; + expressions.push_back(std::move(vectorExpr)); + for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j) + expressions.push_back(ShaderBuilder::ConstantValue(ExpressionType{ targetMatrixType.type }, (i == j + vectorComponentCount) ? 1 : 0, castExpr.sourceLocation)); //< set 1 to diagonal + + vecCast = ShaderBuilder::Cast(ExpressionType{ VectorType{ targetMatrixType.rowCount, targetMatrixType.type } }, std::move(expressions)); + vecCast->sourceLocation = castExpr.sourceLocation; + + columnCastExpr = std::move(vecCast); + } + else + { + std::array swizzleComponents; + std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0); + + auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount); + swizzleExpr->sourceLocation = castExpr.sourceLocation; + + columnCastExpr = std::move(swizzleExpr); + } + } + else + columnCastExpr = std::move(vectorExpr); + + columnCastExpr->cachedExpressionType = VectorType{ targetMatrixType.rowCount, targetMatrixType.type }; + + // temp[i] = columnCastExpr + auto assignExpr = ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(columnCastExpr)); + assignExpr->sourceLocation = castExpr.sourceLocation; + + AppendStatement(ShaderBuilder::ExpressionStatement(std::move(assignExpr))); + } + + return ShaderBuilder::Variable(variableIndex, targetType, castExpr.sourceLocation); + } + + return nullptr; + } +} diff --git a/src/NZSL/Ast/Transformations/StatementTransformer.cpp b/src/NZSL/Ast/Transformations/StatementTransformer.cpp deleted file mode 100644 index fbd9c48..0000000 --- a/src/NZSL/Ast/Transformations/StatementTransformer.cpp +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) -// This file is part of the "Nazara Shading Language" project -// For conditions of distribution and use, see copyright notice in Config.hpp - -#include - -namespace nzsl::Ast -{ -#define NZSL_SHADERAST_STATEMENT(Node) \ - StatementPtr StatementTransformer::Transform(Node##Statement&& /*statement*/) \ - { \ - return nullptr; \ - } - -#include - - bool StatementTransformer::TransformModule(Module& module, Context& context, std::string* error) - { - m_context = &context; - - try - { - StatementPtr root = std::move(module.rootNode); - TransformStatement(root); - module.rootNode = Nz::StaticUniquePointerCast(std::move(root)); - } - catch(const std::exception& e) - { - if (!error) - throw; - - *error = e.what(); - return false; - } - - return true; - } - - void StatementTransformer::TransformStatement(StatementPtr& statement) - { - assert(statement); - - m_statementStack.push_back(&statement); - statement->Visit(*this); - m_statementStack.pop_back(); - } - - template - bool StatementTransformer::TransformCurrent() - { - StatementPtr newStatement = Transform(std::move(Nz::SafeCast(**m_statementStack.back()))); - if (!newStatement) - return false; - - *m_statementStack.back() = std::move(newStatement); - return true; - } - - void StatementTransformer::Visit(BranchStatement& node) - { - if (TransformCurrent()) - return; - - for (auto& cond : node.condStatements) - TransformStatement(cond.statement); - - if (node.elseStatement) - TransformStatement(node.elseStatement); - } - - void StatementTransformer::Visit(BreakStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(ConditionalStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(ContinueStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(DeclareAliasStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(DeclareConstStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(DeclareExternalStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(DeclareFunctionStatement& node) - { - if (TransformCurrent()) - return; - - for (auto& statement : node.statements) - TransformStatement(statement); - } - - void StatementTransformer::Visit(DeclareOptionStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(DeclareStructStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(DeclareVariableStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(DiscardStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(ExpressionStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(ForStatement& node) - { - if (TransformCurrent()) - return; - - if (node.statement) - TransformStatement(node.statement); - } - - void StatementTransformer::Visit(ForEachStatement& node) - { - if (TransformCurrent()) - return; - - if (node.statement) - TransformStatement(node.statement); - } - - void StatementTransformer::Visit(ImportStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(MultiStatement& node) - { - if (TransformCurrent()) - return; - - for (auto& statement : node.statements) - TransformStatement(statement); - } - - void StatementTransformer::Visit(NoOpStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(ReturnStatement& /*node*/) - { - TransformCurrent(); - } - - void StatementTransformer::Visit(ScopedStatement& node) - { - if (TransformCurrent()) - return; - - if (node.statement) - TransformStatement(node.statement); - } - - void StatementTransformer::Visit(WhileStatement& node) - { - if (TransformCurrent()) - return; - - if (node.body) - TransformStatement(node.body); - } -} diff --git a/src/NZSL/Ast/Transformations/SwizzleTransformer.cpp b/src/NZSL/Ast/Transformations/SwizzleTransformer.cpp new file mode 100644 index 0000000..127cd8f --- /dev/null +++ b/src/NZSL/Ast/Transformations/SwizzleTransformer.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include + +namespace nzsl::Ast +{ + bool SwizzleTransformer::Transform(Module& module, Context& context, const Options& options, std::string* error) + { + m_options = &options; + + return TransformModule(module, context, error); + } + + ExpressionPtr SwizzleTransformer::Transform(SwizzleExpression&& swizzle) + { + const ExpressionType* exprType = GetExpressionType(*swizzle.expression); + if (!exprType) + return nullptr; + + const ExpressionType& resolvedExprType = ResolveAlias(*exprType); + + if (m_options->removeScalarSwizzling && IsPrimitiveType(resolvedExprType)) + { + for (std::size_t i = 0; i < swizzle.componentCount; ++i) + { + if (swizzle.components[i] != 0) + throw CompilerInvalidScalarSwizzleError{ swizzle.sourceLocation }; + } + + if (swizzle.componentCount == 1) + return std::move(swizzle.expression); //< remove swizzle expression (a.x => a) + + // Use a Cast expression to replace swizzle + ExpressionPtr expression = CacheExpression(std::move(swizzle.expression)); //< Since we are going to use a value multiple times, cache it if required + + PrimitiveType baseType = std::get(resolvedExprType); + + auto cast = std::make_unique(); + cast->sourceLocation = swizzle.sourceLocation; + cast->targetType = ExpressionType{ VectorType{ swizzle.componentCount, baseType } }; + cast->cachedExpressionType = swizzle.cachedExpressionType; + + cast->expressions.reserve(swizzle.componentCount); + for (std::size_t j = 0; j < swizzle.componentCount; ++j) + cast->expressions.push_back(Clone(*expression)); + + return cast; + } + + return nullptr; + } +} diff --git a/src/NZSL/Ast/Transformations/Transformer.cpp b/src/NZSL/Ast/Transformations/Transformer.cpp new file mode 100644 index 0000000..bbe37ec --- /dev/null +++ b/src/NZSL/Ast/Transformations/Transformer.cpp @@ -0,0 +1,515 @@ +// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Shading Language" project +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include + +namespace nzsl::Ast +{ + StatementPtr Transformer::Unscope(StatementPtr&& statement) + { + if (statement->GetType() == NodeType::ScopedStatement) + return std::move(static_cast(*statement).statement); + else + return std::move(statement); + } + + void Transformer::AppendStatement(StatementPtr statement) + { + m_currentStatementList->insert(m_currentStatementList->begin() + m_currentStatementListIndex, std::move(statement)); + m_currentStatementListIndex++; + } + + ExpressionPtr Transformer::CacheExpression(ExpressionPtr expression) + { + assert(expression); + + // No need to cache LValues (variables/constants) (TODO: Improve this, as constants don't need to be cached as well) + if (GetExpressionCategory(*expression) == ExpressionCategory::LValue) + return expression; + + DeclareVariableStatement* variableDeclaration = DeclareVariable("cachedResult", std::move(expression)); + + auto varExpr = std::make_unique(); + varExpr->sourceLocation = variableDeclaration->sourceLocation; + varExpr->variableId = *variableDeclaration->varIndex; + + return varExpr; + } + + DeclareVariableStatement* Transformer::DeclareVariable(std::string_view name, ExpressionPtr initialExpr) + { + DeclareVariableStatement* var = DeclareVariable(name, *GetExpressionType(*initialExpr, false), initialExpr->sourceLocation); + var->initialExpression = std::move(initialExpr); + + return var; + } + + DeclareVariableStatement* Transformer::DeclareVariable(std::string_view name, Ast::ExpressionType type, SourceLocation sourceLocation) + { + assert(m_currentStatementList); + + auto variableDeclaration = ShaderBuilder::DeclareVariable(fmt::format("_nzsl_{}", name), nullptr); + variableDeclaration->sourceLocation = std::move(sourceLocation); + variableDeclaration->varIndex = m_context->nextVariableIndex++; + variableDeclaration->varType = std::move(type); + + DeclareVariableStatement* varPtr = variableDeclaration.get(); + AppendStatement(std::move(variableDeclaration)); + + return varPtr; + } + + const ExpressionType* Transformer::GetExpressionType(Expression& expr) const + { + return GetExpressionType(expr, m_context->allowPartialSanitization); + } + + const ExpressionType* Transformer::GetExpressionType(Expression& expr, bool allowEmpty) const + { + const ExpressionType* expressionType = Ast::GetExpressionType(expr); + if (!expressionType) + { + if (!allowEmpty) + throw AstInternalError{ expr.sourceLocation, "unexpected missing expression type" }; + } + + return expressionType; + } + +#define NZSL_SHADERAST_NODE(Node, Type) \ + Type##Ptr Transformer::Transform(Node##Type&& /*node*/) \ + { \ + return nullptr; \ + } + +#include + + void Transformer::TransformExpression(ExpressionPtr& expression) + { + assert(expression); + + m_expressionStack.push_back(&expression); + expression->Visit(*this); + m_expressionStack.pop_back(); + } + + bool Transformer::TransformModule(Module& module, Context& context, std::string* error) + { + m_context = &context; + + try + { + StatementPtr root = std::move(module.rootNode); + TransformStatement(root); + module.rootNode = Nz::StaticUniquePointerCast(std::move(root)); + } + catch(const std::exception& e) + { + if (!error) + throw; + + *error = e.what(); + return false; + } + + return true; + } + + void Transformer::TransformStatement(StatementPtr& statement) + { + assert(statement); + + m_statementStack.push_back(&statement); + statement->Visit(*this); + m_statementStack.pop_back(); + } + + template + bool Transformer::TransformCurrentExpression() + { + ExpressionPtr newExpression = Transform(std::move(Nz::SafeCast(**m_expressionStack.back()))); + if (!newExpression) + return false; + + *m_expressionStack.back() = std::move(newExpression); + return true; + } + + template + bool Transformer::TransformCurrentStatement() + { + StatementPtr newStatement = Transform(std::move(Nz::SafeCast(**m_statementStack.back()))); + if (!newStatement) + return false; + + *m_statementStack.back() = std::move(newStatement); + return true; + } + + void Transformer::Visit(AccessIdentifierExpression& node) + { + if (TransformCurrentExpression()) + return; + + TransformExpression(node.expr); + } + + void Transformer::Visit(AccessIndexExpression& node) + { + if (TransformCurrentExpression()) + return; + + TransformExpression(node.expr); + for (auto& index : node.indices) + TransformExpression(index); + } + + void Transformer::Visit(AliasValueExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(AssignExpression& node) + { + if (TransformCurrentExpression()) + return; + + TransformExpression(node.left); + TransformExpression(node.right); + } + + void Transformer::Visit(BinaryExpression& node) + { + if (TransformCurrentExpression()) + return; + + TransformExpression(node.left); + TransformExpression(node.right); + } + + void Transformer::Visit(CallFunctionExpression& node) + { + if (TransformCurrentExpression()) + return; + + for (auto& param : node.parameters) + TransformExpression(param); + + TransformExpression(node.targetFunction); + } + + void Transformer::Visit(CallMethodExpression& node) + { + if (TransformCurrentExpression()) + return; + + TransformExpression(node.object); + + for (auto& param : node.parameters) + TransformExpression(param); + } + + void Transformer::Visit(CastExpression& node) + { + if (TransformCurrentExpression()) + return; + + for (auto& expr : node.expressions) + TransformExpression(expr); + } + + void Transformer::Visit(ConditionalExpression& node) + { + if (TransformCurrentExpression()) + return; + + TransformExpression(node.truePath); + TransformExpression(node.falsePath); + } + + void Transformer::Visit(ConstantExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(ConstantArrayValueExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(ConstantValueExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(FunctionExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(IdentifierExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(IntrinsicExpression& node) + { + if (TransformCurrentExpression()) + return; + + for (auto& param : node.parameters) + TransformExpression(param); + } + + void Transformer::Visit(IntrinsicFunctionExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(StructTypeExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(SwizzleExpression& node) + { + if (TransformCurrentExpression()) + return; + + if (node.expression) + TransformExpression(node.expression); + } + + void Transformer::Visit(TypeExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(VariableValueExpression& /*node*/) + { + TransformCurrentExpression(); + } + + void Transformer::Visit(UnaryExpression& node) + { + if (TransformCurrentExpression()) + return; + + if (node.expression) + TransformExpression(node.expression); + } + + void Transformer::Visit(BranchStatement& node) + { + if (TransformCurrentStatement()) + return; + + for (auto& cond : node.condStatements) + { + TransformStatement(cond.statement); + + if (m_visitExpressions) + TransformExpression(cond.condition); + } + + if (node.elseStatement) + TransformStatement(node.elseStatement); + } + + void Transformer::Visit(BreakStatement& /*node*/) + { + TransformCurrentStatement(); + } + + void Transformer::Visit(ConditionalStatement& node) + { + if (TransformCurrentStatement()) + return; + + TransformStatement(node.statement); + + if (m_visitExpressions) + TransformExpression(node.condition); + } + + void Transformer::Visit(ContinueStatement& /*node*/) + { + TransformCurrentStatement(); + } + + void Transformer::Visit(DeclareAliasStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (m_visitExpressions) + TransformExpression(node.expression); + } + + void Transformer::Visit(DeclareConstStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (m_visitExpressions) + TransformExpression(node.expression); + } + + void Transformer::Visit(DeclareExternalStatement& /*node*/) + { + TransformCurrentStatement(); + } + + void Transformer::Visit(DeclareFunctionStatement& node) + { + if (TransformCurrentStatement()) + return; + + HandleStatementList(node.statements, [&](StatementPtr& statement) + { + TransformStatement(statement); + }); + } + + void Transformer::Visit(DeclareOptionStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (m_visitExpressions) + { + if (node.defaultValue) + TransformExpression(node.defaultValue); + } + } + + void Transformer::Visit(DeclareStructStatement& /*node*/) + { + TransformCurrentStatement(); + } + + void Transformer::Visit(DeclareVariableStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (m_visitExpressions) + { + if (node.initialExpression) + TransformExpression(node.initialExpression); + } + } + + void Transformer::Visit(DiscardStatement& /*node*/) + { + TransformCurrentStatement(); + } + + void Transformer::Visit(ExpressionStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (m_visitExpressions) + TransformExpression(node.expression); + } + + void Transformer::Visit(ForStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (node.statement) + TransformStatement(node.statement); + + if (m_visitExpressions) + { + TransformExpression(node.fromExpr); + TransformExpression(node.toExpr); + + if (node.stepExpr) + TransformExpression(node.stepExpr); + } + } + + void Transformer::Visit(ForEachStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (node.statement) + TransformStatement(node.statement); + + if (m_visitExpressions) + TransformExpression(node.expression); + } + + void Transformer::Visit(ImportStatement& /*node*/) + { + TransformCurrentStatement(); + } + + void Transformer::Visit(MultiStatement& node) + { + if (TransformCurrentStatement()) + return; + + HandleStatementList(node.statements, [&](StatementPtr& statement) + { + TransformStatement(statement); + }); + } + + void Transformer::Visit(NoOpStatement& /*node*/) + { + TransformCurrentStatement(); + } + + void Transformer::Visit(ReturnStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (m_visitExpressions) + TransformExpression(node.returnExpr); + } + + void Transformer::Visit(ScopedStatement& node) + { + if (TransformCurrentStatement()) + return; + + std::vector statementList; + HandleStatementList(statementList, [&] + { + TransformStatement(node.statement); + }); + + // To handle the case where our scoped statement does not contain a statement list but requires + // a new variable to be introduced, we need to be able to add a MultiStatement automatically + if (!statementList.empty()) + { + // Turn the scoped statement into a scoped + multi statement + statementList.push_back(std::move(node.statement)); + + node.statement = ShaderBuilder::MultiStatement(std::move(statementList)); + node.statement->sourceLocation = node.sourceLocation; + } + } + + void Transformer::Visit(WhileStatement& node) + { + if (TransformCurrentStatement()) + return; + + if (node.body) + TransformStatement(node.body); + + if (m_visitExpressions) + TransformExpression(node.condition); + } +}