Skip to content

Commit

Permalink
Make some changes on named externals
Browse files Browse the repository at this point in the history
  • Loading branch information
SirLynix committed Nov 30, 2024
1 parent 3c93f5b commit bbd6e78
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 81 deletions.
3 changes: 0 additions & 3 deletions include/NZSL/Ast/Compare.inl
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,6 @@ namespace nzsl::Ast
if (!Compare(lhs.variableId, rhs.variableId, params))
return false;

if (!Compare(lhs.prefix, rhs.prefix, params))
return false;

return true;
}

Expand Down
1 change: 0 additions & 1 deletion include/NZSL/Ast/Nodes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ namespace nzsl::Ast
void Visit(ExpressionVisitor& visitor) override;

std::size_t variableId;
std::string prefix;
};

struct NZSL_API UnaryExpression : Expression
Expand Down
4 changes: 2 additions & 2 deletions include/NZSL/Ast/SanitizeVisitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ namespace nzsl::Ast

std::size_t RegisterAlias(std::string name, std::optional<Identifier> aliasData, std::optional<std::size_t> index, const SourceLocation& sourceLocation);
std::size_t RegisterConstant(std::string name, std::optional<ConstantValue> value, std::optional<std::size_t> index, const SourceLocation& sourceLocation);
std::size_t RegisterExternalBlock(std::string name, std::size_t externalBlockIndex, const SourceLocation& sourceLocation);
std::size_t RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index, const SourceLocation& sourceLocation);
std::size_t RegisterIntrinsic(std::string name, IntrinsicType type);
std::size_t RegisterModule(std::string moduleIdentifier, std::size_t moduleIndex);
void RegisterReservedName(std::string name);
void RegisterExternalName(std::string name, const SourceLocation& sourceLocation);
std::size_t RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index, const SourceLocation& sourceLocation);
std::size_t RegisterType(std::string name, std::optional<ExpressionType> expressionType, std::optional<std::size_t> index, const SourceLocation& sourceLocation);
std::size_t RegisterType(std::string name, std::optional<PartialType> partialType, std::optional<std::size_t> index, const SourceLocation& sourceLocation);
Expand Down Expand Up @@ -208,7 +208,7 @@ namespace nzsl::Ast
{
Alias,
Constant,
External,
ExternalBlock,
Function,
Intrinsic,
Module,
Expand Down
3 changes: 1 addition & 2 deletions src/NZSL/Ast/Cloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ namespace nzsl::Ast
auto clone = std::make_unique<DeclareExternalStatement>();
clone->autoBinding = Clone(node.autoBinding);
clone->bindingSet = Clone(node.bindingSet);
clone->tag = node.tag;
clone->name = node.name;
clone->tag = node.tag;

clone->externalVars.reserve(node.externalVars.size());
for (const auto& var : node.externalVars)
Expand Down Expand Up @@ -594,7 +594,6 @@ namespace nzsl::Ast
{
auto clone = std::make_unique<VariableValueExpression>();
clone->variableId = node.variableId;
clone->prefix = node.prefix;

clone->cachedExpressionType = node.cachedExpressionType;
clone->sourceLocation = node.sourceLocation;
Expand Down
3 changes: 1 addition & 2 deletions src/NZSL/Ast/ReflectVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// For conditions of distribution and use, see copyright notice in Config.hpp

#include <NZSL/Ast/ReflectVisitor.hpp>
#include <stdexcept>

namespace nzsl::Ast
{
Expand Down Expand Up @@ -57,7 +56,7 @@ namespace nzsl::Ast
for (const auto& extVar : node.externalVars)
{
if (extVar.varIndex)
m_callbacks->onVariableIndex(node.name + extVar.name, *extVar.varIndex, extVar.sourceLocation);
m_callbacks->onVariableIndex(extVar.name, *extVar.varIndex, extVar.sourceLocation);
}
}

Expand Down
133 changes: 88 additions & 45 deletions src/NZSL/Ast/SanitizeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
#include <NZSL/Ast/ExportVisitor.hpp>
#include <NZSL/Ast/ExpressionType.hpp>
#include <NZSL/Ast/IndexRemapperVisitor.hpp>
#include <NZSL/Ast/RecursiveVisitor.hpp>
#include <NZSL/Ast/ReflectVisitor.hpp>
#include <NZSL/Ast/Utils.hpp>
#include <NZSL/Lang/Errors.hpp>
#include <NZSL/Lang/LangData.hpp>
#include <fmt/format.h>
#include <frozen/unordered_map.h>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -183,15 +181,21 @@ namespace nzsl::Ast
std::unique_ptr<DependencyCheckerVisitor> dependenciesVisitor;
};

struct NamedExternalBlockData
{
std::shared_ptr<Environment> environment;
};

struct UsedExternalData
{
unsigned int conditionalStatementIndex;
};

static constexpr std::size_t ModuleIdSentinel = std::numeric_limits<std::size_t>::max();
static constexpr std::size_t ModuleIdSentinel = std::numeric_limits<std::size_t>::max();

std::array<DeclareFunctionStatement*, ShaderStageTypeCount> entryFunctions = {};
std::vector<ModuleData> modules;
std::vector<NamedExternalBlockData> namedExternalBlocks;
std::vector<StatementPtr>* currentStatementList = nullptr;
std::unordered_map<std::string, std::size_t> moduleByName;
std::unordered_map<std::uint64_t, UsedExternalData> usedBindingIndexes;
Expand All @@ -205,6 +209,7 @@ namespace nzsl::Ast
IdentifierList<Identifier> aliases;
IdentifierList<IntrinsicType> intrinsics;
IdentifierList<std::size_t> moduleIndices;
IdentifierList<std::size_t> namedExternalBlockIndices;
IdentifierList<StructDescription*> structs;
IdentifierList<std::variant<ExpressionType, NamedPartialType>> types;
IdentifierList<ExpressionType> variableTypes;
Expand Down Expand Up @@ -318,29 +323,44 @@ namespace nzsl::Ast
auto& identifierExpr = static_cast<IdentifierExpression&>(*node.expr);
const IdentifierData* identifierData = FindIdentifier(identifierExpr.identifier);

if (identifierData && identifierData->category == IdentifierCategory::Module)
if (identifierData)
{
std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index, node.sourceLocation);

const auto& env = *m_context->modules[moduleIndex].environment;
identifierData = FindIdentifier(env, node.identifiers.front().identifier);
if (identifierData)
switch (identifierData->category)
{
if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex)
return Cloner::Clone(node);
case IdentifierCategory::ExternalBlock:
{
std::size_t namedExternalBlockIndex = m_context->namedExternalBlockIndices.Retrieve(identifierData->index, node.sourceLocation);

return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation);
}
}
const auto& env = *m_context->namedExternalBlocks[namedExternalBlockIndex].environment;
identifierData = FindIdentifier(env, node.identifiers.front().identifier);
if (identifierData)
{
if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex)
return Cloner::Clone(node);

if (identifierData && identifierData->category == IdentifierCategory::External)
{
identifierData = FindIdentifier(identifierExpr.identifier + node.identifiers.front().identifier);
if (identifierData)
{
auto variableValuePtr = HandleIdentifier(identifierData, node.identifiers.front().sourceLocation);
static_cast<VariableValueExpression*>(variableValuePtr.get())->prefix = identifierExpr.identifier;
return variableValuePtr;
return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation);
}
break;
}

case IdentifierCategory::Module:
{
std::size_t moduleIndex = m_context->moduleIndices.Retrieve(identifierData->index, node.sourceLocation);

const auto& env = *m_context->modules[moduleIndex].environment;
identifierData = FindIdentifier(env, node.identifiers.front().identifier);
if (identifierData)
{
if (m_context->options.partialSanitization && identifierData->conditionalIndex != m_context->currentConditionalIndex)
return Cloner::Clone(node);

return HandleIdentifier(identifierData, node.identifiers.front().sourceLocation);
}
break;
}

default:
break;
}
}
}
Expand Down Expand Up @@ -1474,9 +1494,18 @@ NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized")
}
};

std::optional<std::size_t> namedExternalBlockIndex;
std::shared_ptr<Environment> previousEnv;
if (!clone->name.empty())
{
RegisterExternalName(clone->name, clone->sourceLocation);
namedExternalBlockIndex = m_context->namedExternalBlocks.size();
auto& namedExternal = m_context->namedExternalBlocks.emplace_back();
namedExternal.environment = std::make_shared<Environment>();

RegisterExternalBlock(clone->name, *namedExternalBlockIndex, clone->sourceLocation);

previousEnv = std::move(m_context->currentEnv);
m_context->currentEnv = namedExternal.environment;
}

bool hasUnresolved = false;
Expand All @@ -1485,25 +1514,30 @@ NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized")
{
auto& extVar = clone->externalVars[i];

SanitizeIdentifier(extVar.name, IdentifierScope::ExternalVariable);
std::string fullName;
if (!clone->name.empty())
{
fullName = fmt::format("{}_{}", clone->name, extVar.name);
SanitizeIdentifier(fullName, IdentifierScope::ExternalVariable);
}

std::string& internalName = (!clone->name.empty()) ? fullName : extVar.name;

std::string fullName = clone->name + extVar.name;

Context::UsedExternalData usedBindingData;
usedBindingData.conditionalStatementIndex = m_context->currentConditionalIndex;

if (auto it = m_context->declaredExternalVar.find(fullName); it != m_context->declaredExternalVar.end())
if (auto it = m_context->declaredExternalVar.find(internalName); it != m_context->declaredExternalVar.end())
{
if (it->second.conditionalStatementIndex == m_context->currentConditionalIndex || usedBindingData.conditionalStatementIndex == m_context->currentConditionalIndex)
throw CompilerExtAlreadyDeclaredError{ extVar.sourceLocation, extVar.name };
}

m_context->declaredExternalVar.emplace(fullName, usedBindingData);
m_context->declaredExternalVar.emplace(internalName, usedBindingData);

std::optional<ExpressionType> resolvedType = ResolveTypeExpr(extVar.type, false, node.sourceLocation);
if (!resolvedType.has_value())
{
RegisterUnresolved(fullName);
RegisterUnresolved(extVar.name);
hasUnresolved = true;
continue;
}
Expand Down Expand Up @@ -1574,9 +1608,13 @@ NAZARA_WARNING_GCC_DISABLE("-Wmaybe-uninitialized")
}

extVar.type = std::move(resolvedType).value();
extVar.varIndex = RegisterVariable(fullName, std::move(varType), extVar.varIndex, extVar.sourceLocation);
extVar.varIndex = RegisterVariable(extVar.name, std::move(varType), extVar.varIndex, extVar.sourceLocation);
SanitizeIdentifier(extVar.name, IdentifierScope::ExternalVariable);
}

if (previousEnv)
m_context->currentEnv = std::move(previousEnv);

// Resolve auto-binding entries when explicit binding are known
if (!hasUnresolved)
{
Expand Down Expand Up @@ -2886,7 +2924,7 @@ NAZARA_WARNING_POP()
return Clone(constantExpr); //< Turn ConstantExpression into ConstantValueExpression
}

case IdentifierCategory::External:
case IdentifierCategory::ExternalBlock:
throw AstUnexpectedIdentifierError{ sourceLocation, "external" };

case IdentifierCategory::Function:
Expand Down Expand Up @@ -3658,6 +3696,25 @@ NAZARA_WARNING_POP()
return constantIndex;
}

std::size_t SanitizeVisitor::RegisterExternalBlock(std::string name, std::size_t externalBlockIndex, const SourceLocation& sourceLocation)
{
if (!IsIdentifierAvailable(name))
throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name };

std::size_t index = m_context->namedExternalBlockIndices.Register(externalBlockIndex, std::nullopt, {});

m_context->currentEnv->identifiersInScope.push_back({
std::move(name),
{
index,
IdentifierCategory::ExternalBlock,
m_context->currentConditionalIndex
}
});

return index;
}

std::size_t SanitizeVisitor::RegisterFunction(std::string name, std::optional<FunctionData> funcData, std::optional<std::size_t> index, const SourceLocation& sourceLocation)
{
if (auto* identifier = FindIdentifier(name))
Expand Down Expand Up @@ -3760,20 +3817,6 @@ NAZARA_WARNING_POP()
});
}

void SanitizeVisitor::RegisterExternalName(std::string name, const SourceLocation& sourceLocation)
{
if (!IsIdentifierAvailable(name))
throw CompilerIdentifierAlreadyUsedError{ sourceLocation, name };

m_context->currentEnv->identifiersInScope.push_back({
std::move(name),
{
std::numeric_limits<std::size_t>::max(),
IdentifierCategory::External
}
});
}

std::size_t SanitizeVisitor::RegisterStruct(std::string name, std::optional<StructDescription*> description, std::optional<std::size_t> index, const SourceLocation& sourceLocation)
{
bool unresolved = false;
Expand Down
41 changes: 36 additions & 5 deletions src/NZSL/GlslWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#include <NazaraUtils/CallOnExit.hpp>
#include <NazaraUtils/PathUtils.hpp>
#include <NZSL/Enums.hpp>
#include <NZSL/ShaderBuilder.hpp>
#include <NZSL/Ast/Cloner.hpp>
#include <NZSL/Ast/ConstantPropagationVisitor.hpp>
#include <NZSL/Ast/ConstantValue.hpp>
#include <NZSL/Ast/EliminateUnusedPassVisitor.hpp>
Expand All @@ -26,6 +24,7 @@
#include <sstream>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>

namespace nzsl
{
Expand Down Expand Up @@ -345,6 +344,7 @@ namespace nzsl
std::unordered_map<std::size_t, std::string> variableNames;
std::unordered_map<std::string, unsigned int> explicitTextureBinding;
std::unordered_map<std::string, unsigned int> explicitUniformBlockBinding;
std::unordered_set<std::string> reservedNames;
Nz::Bitset<> declaredFunctions;
const GlslWriter::BindingMapping& bindingMapping;
GlslWriterPreVisitor previsitor;
Expand Down Expand Up @@ -2267,7 +2267,24 @@ namespace nzsl
AppendComment("struct tag: " + structInfo.desc->tag);
}

std::string varName = node.name + externalVar.name + m_currentState->moduleSuffix;
std::string varName = externalVar.name + m_currentState->moduleSuffix;
if (!node.name.empty())
varName = fmt::format("{}_{}", node.name, varName);

if (m_currentState->reservedNames.count(varName) > 0)
{
unsigned int cloneIndex = 2;
std::string candidateName;
do
{
candidateName = fmt::format("{}_{}", varName, cloneIndex++);
}
while (m_currentState->reservedNames.count(candidateName) > 0);

varName = std::move(candidateName);
}

m_currentState->reservedNames.insert(varName);

// Layout handling
bool hasLayout = false;
Expand Down Expand Up @@ -2519,9 +2536,23 @@ namespace nzsl
void GlslWriter::Visit(Ast::DeclareVariableStatement& node)
{
assert(node.varIndex);
RegisterVariable(*node.varIndex, node.varName);

AppendVariableDeclaration(node.varType.GetResultingValue(), node.varName);
std::string varName = node.varName;
if (m_currentState->reservedNames.count(varName) > 0)
{
unsigned int cloneIndex = 2;
std::string candidateName;
do
{
candidateName = fmt::format("{}_{}", varName, cloneIndex++);
} while (m_currentState->reservedNames.count(candidateName) > 0);

varName = std::move(candidateName);
}

AppendVariableDeclaration(node.varType.GetResultingValue(), varName);
RegisterVariable(*node.varIndex, std::move(varName));

if (node.initialExpression)
{
Append(" = ");
Expand Down
Loading

0 comments on commit bbd6e78

Please sign in to comment.