-
-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Compiler: Begin work to split sanitizer into multiple passes
- Loading branch information
Showing
6 changed files
with
501 additions
and
0 deletions.
There are no files selected for viewing
40 changes: 40 additions & 0 deletions
40
include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) | ||
// This file is part of the "Nazara Shading Language" project | ||
// For conditions of distribution and use, see copyright notice in Config.hpp | ||
|
||
#pragma once | ||
|
||
#ifndef NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP | ||
#define NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP | ||
|
||
#include <NZSL/Ast/Transformations/StatementTransformer.hpp> | ||
|
||
namespace nzsl::Ast | ||
{ | ||
class NZSL_API ForToWhileTransformer final : public StatementTransformer | ||
{ | ||
public: | ||
struct Options; | ||
|
||
inline bool Transform(Module& module, Context& context, std::string* error = nullptr); | ||
bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr); | ||
|
||
struct Options | ||
{ | ||
bool allowPartialSanitization = false; | ||
bool reduceForEachLoopsToWhile = true; | ||
bool reduceForLoopsToWhile = true; | ||
}; | ||
|
||
private: | ||
using StatementTransformer::Transform; | ||
StatementPtr Transform(ForEachStatement&& statement) override; | ||
StatementPtr Transform(ForStatement&& statement) override; | ||
|
||
const Options* m_options; | ||
}; | ||
} | ||
|
||
#include <NZSL/Ast/Transformations/ForToWhileTransformer.inl> | ||
|
||
#endif // NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP |
11 changes: 11 additions & 0 deletions
11
include/NZSL/Ast/Transformations/ForToWhileTransformer.inl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) | ||
// This file is part of the "Nazara Shading Language" project | ||
// For conditions of distribution and use, see copyright notice in Config.hpp | ||
|
||
namespace nzsl::Ast | ||
{ | ||
inline bool ForToWhileTransformer::Transform(Module& module, Context& context, std::string* error) | ||
{ | ||
return Transform(module, context, {}, error); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) | ||
// This file is part of the "Nazara Shading Language" project | ||
// For conditions of distribution and use, see copyright notice in Config.hpp | ||
|
||
#pragma once | ||
|
||
#ifndef NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP | ||
#define NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP | ||
|
||
#include <NZSL/Config.hpp> | ||
#include <NZSL/Ast/Module.hpp> | ||
#include <NZSL/Ast/StatementVisitor.hpp> | ||
|
||
namespace nzsl::Ast | ||
{ | ||
class NZSL_API StatementTransformer : public StatementVisitor | ||
{ | ||
public: | ||
struct Context | ||
{ | ||
std::size_t nextVariableIndex; | ||
}; | ||
|
||
protected: | ||
#define NZSL_SHADERAST_STATEMENT(Node) virtual StatementPtr Transform(Node##Statement&& statement); | ||
#include <NZSL/Ast/NodeList.hpp> | ||
|
||
bool TransformModule(Module& module, Context& context, std::string* error = nullptr); | ||
void TransformStatement(StatementPtr& statement); | ||
|
||
Context* m_context; | ||
|
||
private: | ||
template<typename T> bool TransformCurrent(); | ||
|
||
void Visit(BranchStatement& node) override; | ||
void Visit(BreakStatement& node) override; | ||
void Visit(ConditionalStatement& node) override; | ||
void Visit(ContinueStatement& node) override; | ||
void Visit(DeclareAliasStatement& node) override; | ||
void Visit(DeclareConstStatement& node) override; | ||
void Visit(DeclareExternalStatement& node) override; | ||
void Visit(DeclareFunctionStatement& node) override; | ||
void Visit(DeclareOptionStatement& node) override; | ||
void Visit(DeclareStructStatement& node) override; | ||
void Visit(DeclareVariableStatement& node) override; | ||
void Visit(DiscardStatement& node) override; | ||
void Visit(ExpressionStatement& node) override; | ||
void Visit(ForStatement& node) override; | ||
void Visit(ForEachStatement& node) override; | ||
void Visit(ImportStatement& node) override; | ||
void Visit(MultiStatement& node) override; | ||
void Visit(NoOpStatement& node) override; | ||
void Visit(ReturnStatement& node) override; | ||
void Visit(ScopedStatement& node) override; | ||
void Visit(WhileStatement& node) override; | ||
|
||
std::vector<StatementPtr*> m_statementStack; | ||
}; | ||
} | ||
|
||
#include <NZSL/Ast/Transformations/StatementTransformer.inl> | ||
|
||
#endif // NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) | ||
// This file is part of the "Nazara Shading Language" project | ||
// For conditions of distribution and use, see copyright notice in Config.hpp | ||
|
||
namespace nzsl::Ast | ||
{ | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com) | ||
// This file is part of the "Nazara Shading Language" project | ||
// For conditions of distribution and use, see copyright notice in Config.hpp | ||
|
||
#include <NZSL/Ast/Transformations/ForToWhileTransformer.hpp> | ||
#include <NZSL/Lang/Errors.hpp> | ||
|
||
namespace nzsl::Ast | ||
{ | ||
bool ForToWhileTransformer::Transform(Module& module, Context& context, const Options& options, std::string* error) | ||
{ | ||
m_options = &options; | ||
|
||
return TransformModule(module, context, error); | ||
} | ||
|
||
StatementPtr ForToWhileTransformer::Transform(ForEachStatement&& forEachStatement) | ||
{ | ||
if (!m_options->reduceForEachLoopsToWhile) | ||
return nullptr; | ||
|
||
const ExpressionType* exprType = GetExpressionType(*forEachStatement.expression); | ||
if (!exprType) | ||
{ | ||
if (!m_options->allowPartialSanitization) | ||
throw AstInternalError{ forEachStatement.sourceLocation, "unexpected missing expression type" }; | ||
|
||
return nullptr; | ||
} | ||
|
||
const ExpressionType& resolvedExprType = ResolveAlias(*exprType); | ||
|
||
ExpressionType innerType; | ||
if (IsArrayType(resolvedExprType)) | ||
{ | ||
const ArrayType& arrayType = std::get<ArrayType>(resolvedExprType); | ||
innerType = arrayType.containedType->type; | ||
} | ||
else | ||
throw CompilerForEachUnsupportedTypeError{ forEachStatement.sourceLocation, ToString(*exprType) }; | ||
|
||
auto multi = std::make_unique<MultiStatement>(); | ||
multi->sourceLocation = forEachStatement.sourceLocation; | ||
|
||
if (IsArrayType(resolvedExprType)) | ||
{ | ||
const ArrayType& arrayType = std::get<ArrayType>(resolvedExprType); | ||
|
||
multi->statements.reserve(2); | ||
|
||
// Counter variable | ||
auto counterVariable = ShaderBuilder::DeclareVariable("_nzsl_counter", ShaderBuilder::ConstantValue(0u)); | ||
counterVariable->sourceLocation = forEachStatement.sourceLocation; | ||
counterVariable->varIndex = m_context->nextVariableIndex++; | ||
|
||
std::size_t counterVarIndex = counterVariable->varIndex.value(); | ||
|
||
multi->statements.emplace_back(std::move(counterVariable)); | ||
|
||
auto whileStatement = std::make_unique<WhileStatement>(); | ||
whileStatement->unroll = std::move(forEachStatement.unroll); | ||
|
||
// While condition | ||
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation), ShaderBuilder::ConstantValue(arrayType.length, forEachStatement.sourceLocation)); | ||
whileStatement->condition = std::move(condition); | ||
|
||
// While body | ||
auto body = std::make_unique<MultiStatement>(); | ||
body->statements.reserve(3); | ||
|
||
auto accessIndex = ShaderBuilder::AccessIndex(std::move(forEachStatement.expression), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation)); | ||
|
||
auto elementVariable = ShaderBuilder::DeclareVariable(forEachStatement.varName, std::move(accessIndex)); | ||
elementVariable->varIndex = forEachStatement.varIndex; //< Preserve var index | ||
|
||
body->statements.emplace_back(std::move(elementVariable)); | ||
body->statements.emplace_back(std::move(forEachStatement.statement)); | ||
|
||
auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation), ShaderBuilder::ConstantValue(1u, forEachStatement.sourceLocation)); | ||
|
||
body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); | ||
|
||
whileStatement->body = std::move(body); | ||
|
||
multi->statements.emplace_back(std::move(whileStatement)); | ||
} | ||
|
||
return ShaderBuilder::Scoped(std::move(multi)); | ||
} | ||
|
||
StatementPtr ForToWhileTransformer::Transform(ForStatement&& forStatement) | ||
{ | ||
if (!m_options->reduceForLoopsToWhile) | ||
return nullptr; | ||
|
||
Expression& fromExpr = *forStatement.fromExpr; | ||
const ExpressionType* fromExprType = GetExpressionType(fromExpr); | ||
if (!fromExprType) | ||
{ | ||
if (!m_options->allowPartialSanitization) | ||
throw AstInternalError{ forStatement.sourceLocation, "unexpected missing expression type" }; | ||
|
||
return nullptr; | ||
} | ||
|
||
const ExpressionType& resolvedFromExprType = ResolveAlias(*fromExprType); | ||
if (!IsPrimitiveType(resolvedFromExprType)) | ||
throw CompilerForFromTypeExpectIntegerTypeError{ fromExpr.sourceLocation, ToString(*fromExprType) }; | ||
|
||
PrimitiveType counterType = std::get<PrimitiveType>(resolvedFromExprType); | ||
if (counterType != PrimitiveType::Int32 && counterType != PrimitiveType::UInt32) | ||
throw CompilerForFromTypeExpectIntegerTypeError{ fromExpr.sourceLocation, ToString(*fromExprType) }; | ||
|
||
auto multi = std::make_unique<MultiStatement>(); | ||
multi->sourceLocation = forStatement.sourceLocation; | ||
|
||
// Counter variable | ||
auto counterVariable = ShaderBuilder::DeclareVariable(forStatement.varName, std::move(forStatement.fromExpr)); | ||
counterVariable->sourceLocation = forStatement.sourceLocation; | ||
counterVariable->varIndex = forStatement.varIndex; | ||
|
||
std::size_t counterVarIndex = counterVariable->varIndex.value(); | ||
multi->statements.emplace_back(std::move(counterVariable)); | ||
|
||
// Target variable | ||
auto targetVariable = ShaderBuilder::DeclareVariable("_nzsl_to", std::move(forStatement.toExpr)); | ||
targetVariable->sourceLocation = forStatement.sourceLocation; | ||
targetVariable->varIndex = m_context->nextVariableIndex++; | ||
|
||
std::size_t targetVarIndex = targetVariable->varIndex.value(); | ||
multi->statements.emplace_back(std::move(targetVariable)); | ||
|
||
// Step variable | ||
std::optional<std::size_t> stepVarIndex; | ||
|
||
if (forStatement.stepExpr) | ||
{ | ||
auto stepVariable = ShaderBuilder::DeclareVariable("_nzsl_step", std::move(forStatement.stepExpr)); | ||
stepVariable->sourceLocation = forStatement.sourceLocation; | ||
stepVariable->varIndex = m_context->nextVariableIndex++; | ||
|
||
stepVarIndex = stepVariable->varIndex; | ||
multi->statements.emplace_back(std::move(stepVariable)); | ||
} | ||
|
||
// While | ||
auto whileStatement = std::make_unique<WhileStatement>(); | ||
whileStatement->sourceLocation = forStatement.sourceLocation; | ||
whileStatement->unroll = std::move(forStatement.unroll); | ||
|
||
// While condition | ||
auto conditionCounterVariable = ShaderBuilder::Variable(counterVarIndex, counterType, forStatement.sourceLocation); | ||
auto conditionTargetVariable = ShaderBuilder::Variable(targetVarIndex, counterType, forStatement.sourceLocation); | ||
|
||
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, std::move(conditionCounterVariable), std::move(conditionTargetVariable)); | ||
condition->sourceLocation = forStatement.sourceLocation; | ||
|
||
whileStatement->condition = std::move(condition); | ||
|
||
// While body | ||
auto body = std::make_unique<MultiStatement>(); | ||
body->statements.reserve(2); | ||
|
||
// Counter and increment | ||
ExpressionPtr incrExpr; | ||
if (stepVarIndex) | ||
incrExpr = ShaderBuilder::Variable(*stepVarIndex, counterType, forStatement.sourceLocation); | ||
else | ||
incrExpr = (counterType == PrimitiveType::Int32) ? ShaderBuilder::ConstantValue(1, forStatement.sourceLocation) : ShaderBuilder::ConstantValue(1u, forStatement.sourceLocation); | ||
|
||
incrExpr->sourceLocation = forStatement.sourceLocation; | ||
|
||
auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType, forStatement.sourceLocation), std::move(incrExpr)); | ||
incrCounter->sourceLocation = forStatement.sourceLocation; | ||
|
||
body->statements.emplace_back(std::move(forStatement.statement)); | ||
body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter))); | ||
|
||
whileStatement->body = std::move(body); | ||
|
||
multi->statements.emplace_back(std::move(whileStatement)); | ||
|
||
return ShaderBuilder::Scoped(std::move(multi)); | ||
} | ||
} |
Oops, something went wrong.