Skip to content

Commit

Permalink
Compiler: Begin work to split sanitizer into multiple passes
Browse files Browse the repository at this point in the history
  • Loading branch information
SirLynix committed Jun 21, 2024
1 parent d3d9eeb commit 7b608cd
Show file tree
Hide file tree
Showing 6 changed files with 501 additions and 0 deletions.
40 changes: 40 additions & 0 deletions include/NZSL/Ast/Transformations/ForToWhileTransformer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Shading Language" project
// For conditions of distribution and use, see copyright notice in Config.hpp

#pragma once

#ifndef NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP
#define NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP

#include <NZSL/Ast/Transformations/StatementTransformer.hpp>

namespace nzsl::Ast
{
class NZSL_API ForToWhileTransformer final : public StatementTransformer
{
public:
struct Options;

inline bool Transform(Module& module, Context& context, std::string* error = nullptr);
bool Transform(Module& module, Context& context, const Options& options, std::string* error = nullptr);

struct Options
{
bool allowPartialSanitization = false;
bool reduceForEachLoopsToWhile = true;
bool reduceForLoopsToWhile = true;
};

private:
using StatementTransformer::Transform;
StatementPtr Transform(ForEachStatement&& statement) override;
StatementPtr Transform(ForStatement&& statement) override;

const Options* m_options;
};
}

#include <NZSL/Ast/Transformations/ForToWhileTransformer.inl>

#endif // NZSL_AST_TRANSFORMATIONS_FORTOWHILETRANSFORMER_HPP
11 changes: 11 additions & 0 deletions include/NZSL/Ast/Transformations/ForToWhileTransformer.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Shading Language" project
// For conditions of distribution and use, see copyright notice in Config.hpp

namespace nzsl::Ast
{
inline bool ForToWhileTransformer::Transform(Module& module, Context& context, std::string* error)
{
return Transform(module, context, {}, error);
}
}
64 changes: 64 additions & 0 deletions include/NZSL/Ast/Transformations/StatementTransformer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Shading Language" project
// For conditions of distribution and use, see copyright notice in Config.hpp

#pragma once

#ifndef NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP
#define NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP

#include <NZSL/Config.hpp>
#include <NZSL/Ast/Module.hpp>
#include <NZSL/Ast/StatementVisitor.hpp>

namespace nzsl::Ast
{
class NZSL_API StatementTransformer : public StatementVisitor
{
public:
struct Context
{
std::size_t nextVariableIndex;
};

protected:
#define NZSL_SHADERAST_STATEMENT(Node) virtual StatementPtr Transform(Node##Statement&& statement);
#include <NZSL/Ast/NodeList.hpp>

bool TransformModule(Module& module, Context& context, std::string* error = nullptr);
void TransformStatement(StatementPtr& statement);

Context* m_context;

private:
template<typename T> bool TransformCurrent();

void Visit(BranchStatement& node) override;
void Visit(BreakStatement& node) override;
void Visit(ConditionalStatement& node) override;
void Visit(ContinueStatement& node) override;
void Visit(DeclareAliasStatement& node) override;
void Visit(DeclareConstStatement& node) override;
void Visit(DeclareExternalStatement& node) override;
void Visit(DeclareFunctionStatement& node) override;
void Visit(DeclareOptionStatement& node) override;
void Visit(DeclareStructStatement& node) override;
void Visit(DeclareVariableStatement& node) override;
void Visit(DiscardStatement& node) override;
void Visit(ExpressionStatement& node) override;
void Visit(ForStatement& node) override;
void Visit(ForEachStatement& node) override;
void Visit(ImportStatement& node) override;
void Visit(MultiStatement& node) override;
void Visit(NoOpStatement& node) override;
void Visit(ReturnStatement& node) override;
void Visit(ScopedStatement& node) override;
void Visit(WhileStatement& node) override;

std::vector<StatementPtr*> m_statementStack;
};
}

#include <NZSL/Ast/Transformations/StatementTransformer.inl>

#endif // NZSL_AST_TRANSFORMATIONS_STATEMENTTRANSFORMER_HPP
7 changes: 7 additions & 0 deletions include/NZSL/Ast/Transformations/StatementTransformer.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Shading Language" project
// For conditions of distribution and use, see copyright notice in Config.hpp

namespace nzsl::Ast
{
}
185 changes: 185 additions & 0 deletions src/NZSL/Ast/Transformations/ForToWhileTransformer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright (C) 2024 Jérôme "SirLynix" Leclercq (lynix680@gmail.com)
// This file is part of the "Nazara Shading Language" project
// For conditions of distribution and use, see copyright notice in Config.hpp

#include <NZSL/Ast/Transformations/ForToWhileTransformer.hpp>
#include <NZSL/Lang/Errors.hpp>

namespace nzsl::Ast
{
bool ForToWhileTransformer::Transform(Module& module, Context& context, const Options& options, std::string* error)
{
m_options = &options;

return TransformModule(module, context, error);
}

StatementPtr ForToWhileTransformer::Transform(ForEachStatement&& forEachStatement)
{
if (!m_options->reduceForEachLoopsToWhile)
return nullptr;

const ExpressionType* exprType = GetExpressionType(*forEachStatement.expression);
if (!exprType)
{
if (!m_options->allowPartialSanitization)
throw AstInternalError{ forEachStatement.sourceLocation, "unexpected missing expression type" };

return nullptr;
}

const ExpressionType& resolvedExprType = ResolveAlias(*exprType);

ExpressionType innerType;
if (IsArrayType(resolvedExprType))
{
const ArrayType& arrayType = std::get<ArrayType>(resolvedExprType);
innerType = arrayType.containedType->type;
}
else
throw CompilerForEachUnsupportedTypeError{ forEachStatement.sourceLocation, ToString(*exprType) };

auto multi = std::make_unique<MultiStatement>();
multi->sourceLocation = forEachStatement.sourceLocation;

if (IsArrayType(resolvedExprType))
{
const ArrayType& arrayType = std::get<ArrayType>(resolvedExprType);

multi->statements.reserve(2);

// Counter variable
auto counterVariable = ShaderBuilder::DeclareVariable("_nzsl_counter", ShaderBuilder::ConstantValue(0u));
counterVariable->sourceLocation = forEachStatement.sourceLocation;
counterVariable->varIndex = m_context->nextVariableIndex++;

std::size_t counterVarIndex = counterVariable->varIndex.value();

multi->statements.emplace_back(std::move(counterVariable));

auto whileStatement = std::make_unique<WhileStatement>();
whileStatement->unroll = std::move(forEachStatement.unroll);

// While condition
auto condition = ShaderBuilder::Binary(BinaryType::CompLt, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation), ShaderBuilder::ConstantValue(arrayType.length, forEachStatement.sourceLocation));
whileStatement->condition = std::move(condition);

// While body
auto body = std::make_unique<MultiStatement>();
body->statements.reserve(3);

auto accessIndex = ShaderBuilder::AccessIndex(std::move(forEachStatement.expression), ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation));

auto elementVariable = ShaderBuilder::DeclareVariable(forEachStatement.varName, std::move(accessIndex));
elementVariable->varIndex = forEachStatement.varIndex; //< Preserve var index

body->statements.emplace_back(std::move(elementVariable));
body->statements.emplace_back(std::move(forEachStatement.statement));

auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, PrimitiveType::UInt32, forEachStatement.sourceLocation), ShaderBuilder::ConstantValue(1u, forEachStatement.sourceLocation));

body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));

whileStatement->body = std::move(body);

multi->statements.emplace_back(std::move(whileStatement));
}

return ShaderBuilder::Scoped(std::move(multi));
}

StatementPtr ForToWhileTransformer::Transform(ForStatement&& forStatement)
{
if (!m_options->reduceForLoopsToWhile)
return nullptr;

Expression& fromExpr = *forStatement.fromExpr;
const ExpressionType* fromExprType = GetExpressionType(fromExpr);
if (!fromExprType)
{
if (!m_options->allowPartialSanitization)
throw AstInternalError{ forStatement.sourceLocation, "unexpected missing expression type" };

return nullptr;
}

const ExpressionType& resolvedFromExprType = ResolveAlias(*fromExprType);
if (!IsPrimitiveType(resolvedFromExprType))
throw CompilerForFromTypeExpectIntegerTypeError{ fromExpr.sourceLocation, ToString(*fromExprType) };

PrimitiveType counterType = std::get<PrimitiveType>(resolvedFromExprType);
if (counterType != PrimitiveType::Int32 && counterType != PrimitiveType::UInt32)
throw CompilerForFromTypeExpectIntegerTypeError{ fromExpr.sourceLocation, ToString(*fromExprType) };

auto multi = std::make_unique<MultiStatement>();
multi->sourceLocation = forStatement.sourceLocation;

// Counter variable
auto counterVariable = ShaderBuilder::DeclareVariable(forStatement.varName, std::move(forStatement.fromExpr));
counterVariable->sourceLocation = forStatement.sourceLocation;
counterVariable->varIndex = forStatement.varIndex;

std::size_t counterVarIndex = counterVariable->varIndex.value();
multi->statements.emplace_back(std::move(counterVariable));

// Target variable
auto targetVariable = ShaderBuilder::DeclareVariable("_nzsl_to", std::move(forStatement.toExpr));
targetVariable->sourceLocation = forStatement.sourceLocation;
targetVariable->varIndex = m_context->nextVariableIndex++;

std::size_t targetVarIndex = targetVariable->varIndex.value();
multi->statements.emplace_back(std::move(targetVariable));

// Step variable
std::optional<std::size_t> stepVarIndex;

if (forStatement.stepExpr)
{
auto stepVariable = ShaderBuilder::DeclareVariable("_nzsl_step", std::move(forStatement.stepExpr));
stepVariable->sourceLocation = forStatement.sourceLocation;
stepVariable->varIndex = m_context->nextVariableIndex++;

stepVarIndex = stepVariable->varIndex;
multi->statements.emplace_back(std::move(stepVariable));
}

// While
auto whileStatement = std::make_unique<WhileStatement>();
whileStatement->sourceLocation = forStatement.sourceLocation;
whileStatement->unroll = std::move(forStatement.unroll);

// While condition
auto conditionCounterVariable = ShaderBuilder::Variable(counterVarIndex, counterType, forStatement.sourceLocation);
auto conditionTargetVariable = ShaderBuilder::Variable(targetVarIndex, counterType, forStatement.sourceLocation);

auto condition = ShaderBuilder::Binary(BinaryType::CompLt, std::move(conditionCounterVariable), std::move(conditionTargetVariable));
condition->sourceLocation = forStatement.sourceLocation;

whileStatement->condition = std::move(condition);

// While body
auto body = std::make_unique<MultiStatement>();
body->statements.reserve(2);

// Counter and increment
ExpressionPtr incrExpr;
if (stepVarIndex)
incrExpr = ShaderBuilder::Variable(*stepVarIndex, counterType, forStatement.sourceLocation);
else
incrExpr = (counterType == PrimitiveType::Int32) ? ShaderBuilder::ConstantValue(1, forStatement.sourceLocation) : ShaderBuilder::ConstantValue(1u, forStatement.sourceLocation);

incrExpr->sourceLocation = forStatement.sourceLocation;

auto incrCounter = ShaderBuilder::Assign(AssignType::CompoundAdd, ShaderBuilder::Variable(counterVarIndex, counterType, forStatement.sourceLocation), std::move(incrExpr));
incrCounter->sourceLocation = forStatement.sourceLocation;

body->statements.emplace_back(std::move(forStatement.statement));
body->statements.emplace_back(ShaderBuilder::ExpressionStatement(std::move(incrCounter)));

whileStatement->body = std::move(body);

multi->statements.emplace_back(std::move(whileStatement));

return ShaderBuilder::Scoped(std::move(multi));
}
}
Loading

0 comments on commit 7b608cd

Please sign in to comment.